diff --git a/Cargo.lock b/Cargo.lock index e2c7a049c..307f8b1ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,17 +2,26 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + [[package]] name = "adler2" -version = "2.0.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "aho-corasick" -version = "1.1.4" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -35,9 +44,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.21" +version = "0.6.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" dependencies = [ "anstyle", "anstyle-parse", @@ -50,9 +59,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.14" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" [[package]] name = "anstyle-parse" @@ -65,35 +74,35 @@ dependencies = [ [[package]] name = "anstyle-query" -version = "1.1.5" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" dependencies = [ - "windows-sys", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.11" +version = "3.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.102" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "async-trait" -version = "0.1.89" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", @@ -102,9 +111,24 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.5.0" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "backtrace" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets 0.52.6", +] [[package]] name = "base64" @@ -133,12 +157,12 @@ dependencies = [ [[package]] name = "bindgen" -version = 
"0.72.1" +version = "0.72.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +checksum = "4f72209734318d0b619a5e0f5129918b848c416e122a3c4ce054e03cb87b726f" dependencies = [ "annotate-snippets", - "bitflags 2.11.0", + "bitflags 2.10.0", "cexpr", "clang-sys", "itertools", @@ -178,9 +202,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.11.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" [[package]] name = "block-buffer" @@ -193,9 +217,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" [[package]] name = "bzip2" @@ -218,20 +242,20 @@ dependencies = [ [[package]] name = "caps" -version = "0.5.6" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ddba47aba30b6a889298ad0109c3b8dcb0e8fc993b459daa7067d46f865e0" +checksum = "190baaad529bcfbde9e1a19022c42781bdb6ff9de25721abdb8fd98c0807730b" dependencies = [ "libc", + "thiserror 1.0.69", ] [[package]] name = "cc" -version = "1.2.57" +version = "1.2.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" +checksum = "956a5e21988b87f372569b66183b78babf23ebc2e744b733e4350a752c4dafac" dependencies = [ - "find-msvc-tools", "jobserver", "libc", "shlex", @@ -258,9 +282,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.4" +version = "1.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cfg_aliases" @@ -281,9 +305,9 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.5" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] name = "convert_case" @@ -311,9 +335,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.5.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] @@ -335,9 +359,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" -version = "0.1.7" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", "typenum", @@ -361,9 +385,9 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "env_filter" -version = "1.0.0" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -371,9 +395,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.9" +version = "0.11.8" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -395,31 +419,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + [[package]] name = "filetime" -version = "0.2.27" +version = "0.2.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" dependencies = [ "cfg-if", "libc", "libredox", + "windows-sys 0.60.2", ] -[[package]] -name = "find-msvc-tools" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" - [[package]] name = "flate2" -version = "1.1.9" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" dependencies = [ "crc32fast", "miniz_oxide", @@ -427,9 +452,9 @@ dependencies = [ [[package]] name = "foldhash" -version = "0.2.0" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "generic-array" @@ -443,33 +468,45 @@ dependencies = [ 
[[package]] name = "getrandom" -version = "0.3.4" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", "libc", "r-efi", - "wasip2", + "wasi", ] +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + [[package]] name = "glob" -version = "0.3.3" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "hashbrown" -version = "0.16.1" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" dependencies = [ "allocator-api2", "equivalent", "foldhash", ] +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + [[package]] name = "heck" version = "0.5.0" @@ -478,9 +515,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "imago" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e8e4b92aa0dd860579cfba776dbf0918a3a7ac5cb601af7d3fc835e71592a5b" +checksum = "ae7cfee876c698a1a2ed9c705ab18f21acbed82110f19b51cc458de73426fe2c" dependencies = [ "async-trait", "bincode", @@ -493,17 +530,17 @@ dependencies = [ "tokio", "tracing", "vm-memory", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] 
name = "indexmap" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.17.0", ] [[package]] @@ -514,43 +551,43 @@ checksum = "d8972d5be69940353d5347a1344cb375d9b457d6809b428b05bb1ca2fb9ce007" [[package]] name = "is_terminal_polyfill" -version = "1.70.2" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "itertools" -version = "0.13.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.18" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jiff" -version = "0.2.23" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +checksum = "a194df1107f33c79f4f93d02c80798520551949d59dfad22b6157048a88cca93" dependencies = [ "jiff-static", "log", "portable-atomic", "portable-atomic-util", - "serde_core", + "serde", ] [[package]] name = "jiff-static" -version = "0.2.23" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" 
+checksum = "6c6e1db7ed32c6c71b759497fae34bf7933636f75a251b9e736555da426f6442" dependencies = [ "proc-macro2", "quote", @@ -559,9 +596,9 @@ dependencies = [ [[package]] name = "jobserver" -version = "0.1.34" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ "getrandom", "libc", @@ -569,9 +606,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.91" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ "once_cell", "wasm-bindgen", @@ -670,7 +707,7 @@ name = "krun-display" version = "0.1.0" dependencies = [ "bindgen", - "bitflags 2.11.0", + "bitflags 2.10.0", "log", "static_assertions", "thiserror 2.0.18", @@ -691,7 +728,7 @@ name = "krun-input" version = "0.1.0" dependencies = [ "bindgen", - "bitflags 2.11.0", + "bitflags 2.10.0", "libc", "log", "static_assertions", @@ -756,9 +793,10 @@ name = "krun-vmm" version = "0.1.0-1.18.0" dependencies = [ "bitfield", - "bitflags 2.11.0", + "bitflags 2.10.0", "bzip2", "crossbeam-channel", + "env_logger", "flate2", "iocuddle", "kbs-types", @@ -778,30 +816,36 @@ dependencies = [ "linux-loader", "log", "nix 0.30.1", + "rand", "serde", "serde_json", "tdx", + "tempfile", + "thiserror 2.0.18", + "uds_windows", "vm-memory", "vmm-sys-util 0.14.0", + "windows-sys 0.61.2", + "zerocopy", "zstd", ] [[package]] name = "kvm-bindings" -version = "0.12.1" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a537873e15e8daabb416667e606d9b0abc2a8fb9a45bd5853b888ae0ead82f9" +checksum = "d4b153a59bb3ca930ff8148655b2ef68c34259a623ae08cf2fb9b570b2e45363" dependencies = [ 
"vmm-sys-util 0.14.0", ] [[package]] name = "kvm-ioctls" -version = "0.22.1" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c8f7370330b4f57981e300fa39b02088f2f2a5c2d0f1f994e8090589619c56d" +checksum = "b702df98508cb63ad89dd9beb9f6409761b30edca10d48e57941d3f11513a006" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "kvm-bindings", "libc", "vmm-sys-util 0.14.0", @@ -809,9 +853,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.183" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libkrun" @@ -840,23 +884,22 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.9" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-link", + "windows-targets 0.53.5", ] [[package]] name = "libredox" -version = "0.1.14" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "libc", - "plain", "redox_syscall", ] @@ -866,7 +909,7 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6b8cfa2a7656627b4c92c6b9ef929433acd673d5ab3708cda1b18478ac00df4" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "cc", "convert_case", "cookie-factory", @@ -899,30 +942,30 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.12.1" +version = "0.11.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" [[package]] name = "log" -version = "0.4.29" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "lru" -version = "0.16.3" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dc47f592c06f33f8e3aea9591776ec7c9f9e4124778ff8a3c3b87159f7e593" +checksum = "9f8cc7106155f10bdf99a6f379688f543ad6596a415375b36a59a054ceda1198" dependencies = [ - "hashbrown", + "hashbrown 0.15.3", ] [[package]] name = "memchr" -version = "2.8.0" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memoffset" @@ -950,9 +993,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.9" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", "simd-adler32", @@ -964,7 +1007,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b5b539a76e3f555fb143c3e67d5e05fa1d5fece02a515f6ecf41b3f1a081f58" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "libc", "nix 0.26.4", "rand", @@ -977,7 +1020,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"c6436c562bcdb6f192e0e59f627bff5b0b88f2e1c48264079f4f1d6da42bec2d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "libc", "nix 0.26.4", "vsock", @@ -998,11 +1041,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.30.1" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "cfg-if", "cfg_aliases", "libc", @@ -1011,11 +1054,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.31.2" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "cfg-if", "cfg_aliases", "libc", @@ -1041,17 +1084,26 @@ dependencies = [ "memchr", ] +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" -version = "1.21.4" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "once_cell_polyfill" -version = "1.70.2" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" [[package]] name = "page_size" @@ -1065,9 +1117,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = 
"0.2.17" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -1082,7 +1134,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9688b89abf11d756499f7c6190711d6dbe5a3acdb30c8fbf001d6596d06a8d44" dependencies = [ "anyhow", - "bitflags 2.11.0", + "bitflags 2.10.0", "libc", "libspa", "libspa-sys", @@ -1109,23 +1161,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" -[[package]] -name = "plain" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" - [[package]] name = "portable-atomic" -version = "1.13.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "portable-atomic-util" -version = "0.2.6" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" dependencies = [ "portable-atomic", ] @@ -1141,27 +1187,27 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.106" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.45" 
+version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] [[package]] name = "r-efi" -version = "5.3.0" +version = "5.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" [[package]] name = "rand" @@ -1185,27 +1231,27 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.9.5" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ "getrandom", ] [[package]] name = "redox_syscall" -version = "0.7.3" +version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", ] [[package]] name = "regex" -version = "1.12.3" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -1215,9 +1261,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.14" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = 
[ "aho-corasick", "memchr", @@ -1226,9 +1272,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "remain" @@ -1241,6 +1287,12 @@ dependencies = [ "syn", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc-hash" version = "2.1.1" @@ -1258,28 +1310,34 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.4" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] name = "rustversion" -version = "1.0.22" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" + +[[package]] +name = "ryu" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "semver" -version = "1.0.27" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" [[package]] name = "serde" @@ -1313,15 +1371,14 @@ 
dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", + "ryu", "serde", - "serde_core", - "zmij", ] [[package]] @@ -1372,9 +1429,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.8" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" [[package]] name = "sm3" @@ -1387,9 +1444,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.15.1" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "static_assertions" @@ -1420,9 +1477,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.117" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -1444,9 +1501,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.45" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" dependencies = [ "filetime", "libc", @@ -1465,7 +1522,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"ad59e5bf374211a1fdd8e7439a07d5a5e617fe97f5cf21d03bcd1bf8c82b73af" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.10.0", "iocuddle", "kvm-bindings", "kvm-ioctls", @@ -1474,6 +1531,19 @@ dependencies = [ "vmm-sys-util 0.12.1", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -1516,10 +1586,11 @@ dependencies = [ [[package]] name = "tokio" -version = "1.50.0" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ + "backtrace", "pin-project-lite", ] @@ -1564,9 +1635,9 @@ checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" [[package]] name = "tracing" -version = "0.1.44" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -1575,9 +1646,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.31" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +checksum = "1b1ffbcf9c6f6b99d386e7444eb608ba646ae452a36b39737deb9663b610f662" dependencies = [ "proc-macro2", "quote", @@ -1586,24 +1657,35 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.36" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", ] [[package]] name = "typenum" -version = "1.19.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + +[[package]] +name = "uds_windows" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f6fb2847f6742cd76af783a2a2c49e9375d0a111c7bef6f71cd9e738c72d6e" +dependencies = [ + "memoffset 0.9.1", + "tempfile", + "windows-sys 0.61.2", +] [[package]] name = "unicode-ident" -version = "1.0.24" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-segmentation" @@ -1631,9 +1713,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.22.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "js-sys", "wasm-bindgen", @@ -1641,9 +1723,9 @@ dependencies = [ [[package]] name = "version-compare" -version = "0.2.1" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c2856837ef78f57382f06b2b8563a2f512f7185d732608fd9176cb3b8edf0e" +checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" @@ -1653,9 +1735,9 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" 
[[package]] name = "virtio-bindings" -version = "0.2.7" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091f1f09cfbf2a78563b562e7a949465cce1aef63b6065645188d995162f8868" +checksum = "804f498a26d5a63be7bbb8bdcd3869c3f286c4c4a17108905276454da0caf8cb" [[package]] name = "virtue" @@ -1671,9 +1753,9 @@ checksum = "7e21282841a059bb62627ce8441c491f09603622cd5a21c43bfedc85a2952f23" [[package]] name = "vm-memory" -version = "0.17.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f39348a049689cabd3377cdd9182bf526ec76a6f823b79903896452e9d7a7380" +checksum = "48f1f33aee6ae648087fbed47c2944e2796f7877d4717a59edc8d7cb62f71061" dependencies = [ "libc", "thiserror 2.0.18", @@ -1702,41 +1784,54 @@ dependencies = [ [[package]] name = "vsock" -version = "0.5.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b82aeb12ad864eb8cd26a6c21175d0bdc66d398584ee6c93c76964c3bcfc78ff" +checksum = "4e8b4d00e672f147fc86a09738fadb1445bd1c0a40542378dfb82909deeee688" dependencies = [ "libc", - "nix 0.31.2", + "nix 0.29.0", ] [[package]] -name = "wasip2" -version = "1.0.2+wasi-0.2.9" +name = "wasi" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ - "wit-bindgen", + "wit-bindgen-rt", ] [[package]] name = "wasm-bindgen" -version = "0.2.114" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.100" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.114" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1744,22 +1839,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.114" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ - "bumpalo", "proc-macro2", "quote", "syn", + "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.114" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" dependencies = [ "unicode-ident", ] @@ -1792,6 +1887,24 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -1801,6 +1914,135 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" 
+version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + [[package]] name = "winnow" version = "1.0.2" @@ -1808,10 +2050,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" [[package]] -name = "wit-bindgen" -version = "0.51.0" +name = "wit-bindgen-rt" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags 2.10.0", +] [[package]] name = "xattr" @@ -1825,30 +2070,24 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.47" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.47" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", "syn", ] -[[package]] -name = "zmij" -version = "1.0.21" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" - [[package]] name = "zstd" version = "0.13.3" @@ -1869,9 +2108,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.16+zstd.1.5.7" +version = "2.0.15+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" dependencies = [ "cc", "pkg-config", diff --git a/include/libkrun.h b/include/libkrun.h index b8f8008a5..af0733e4c 100644 --- a/include/libkrun.h +++ b/include/libkrun.h @@ -462,6 +462,25 @@ int32_t krun_add_net_unixgram(uint32_t ctx_id, uint32_t features, uint32_t flags); +/** + * Disables automatic TSI networking for this context. + * + * When no virtio-net devices are added, libkrun normally enables the TSI backend. + * Call this function to keep the guest fully offline instead: no virtio-net + * device will be attached and TSI will remain disabled. + * + * Arguments: + * "ctx_id" - the configuration context ID. + * + * Notes: + * This function only affects the automatic fallback path used when no network + * devices are added. It should be called before krun_start_enter. + * + * Returns: + * Zero on success or a negative error number on failure. + */ +int32_t krun_disable_tsi(uint32_t ctx_id); + /** * Adds an independent virtio-net device with the tap backend. * Call to this function disables TSI backend. 
diff --git a/src/devices/src/virtio/vsock/device.rs b/src/devices/src/virtio/vsock/device.rs index a61043b4c..f57208798 100644 --- a/src/devices/src/virtio/vsock/device.rs +++ b/src/devices/src/virtio/vsock/device.rs @@ -78,6 +78,10 @@ impl Vsock { self.cid } + pub fn enable_tsi(&self) -> bool { + self.muxer.enable_tsi() + } + /// Walk the driver-provided RX queue buffers and attempt to fill them up with any data that we /// have pending. Return `true` if descriptors have been added to the used ring, and `false` /// otherwise. diff --git a/src/devices/src/virtio/vsock/muxer.rs b/src/devices/src/virtio/vsock/muxer.rs index f4c10247e..620c58fa1 100644 --- a/src/devices/src/virtio/vsock/muxer.rs +++ b/src/devices/src/virtio/vsock/muxer.rs @@ -132,6 +132,10 @@ impl VsockMuxer { } } + pub fn enable_tsi(&self) -> bool { + self.tsi_flags.tsi_enabled() + } + pub(crate) fn activate( &mut self, mem: GuestMemoryMmap, diff --git a/src/libkrun/Cargo.toml b/src/libkrun/Cargo.toml index dc9000916..eebb7bb1f 100644 --- a/src/libkrun/Cargo.toml +++ b/src/libkrun/Cargo.toml @@ -29,12 +29,15 @@ log = "0.4.0" once_cell = "1.4.1" krun_display = { package = "krun-display", version = "0.1.0", path = "../display", optional = true, features = ["bindgen_clang_runtime"] } krun_input = { package = "krun-input", version = "0.1.0", path = "../input", optional = true, features = ["bindgen_clang_runtime"] } +rand = "0.9.2" + +vmm = { package = "krun-vmm", version = "=0.1.0-1.18.0", path = "../vmm" } +# Unix-only internal crates (not used by Windows WHPX backend) +[target.'cfg(unix)'.dependencies] devices = { package = "krun-devices", version = "=0.1.0-1.18.0", path = "../devices" } polly = { package = "krun-polly", version = "=0.1.0-1.18.0", path = "../polly" } utils = { package = "krun-utils", version = "=0.1.0-1.18.0", path = "../utils" } -vmm = { package = "krun-vmm", version = "=0.1.0-1.18.0", path = "../vmm" } -rand = "0.9.2" [target.'cfg(target_os = "macos")'.dependencies] hvf = { package = 
"krun-hvf", version = "=0.1.0-1.18.0", path = "../hvf" } @@ -46,6 +49,11 @@ aws-nitro = { package = "krun-aws-nitro", version = "=0.1.0-1.18.0", path = "../ nitro-enclaves = { version = "0.5.0", optional = true } vm-memory = { version = "0.17", features = ["backend-mmap"] } +# Windows-only dependencies (WHPX C API) +[target.'cfg(target_os = "windows")'.dependencies] +env_logger = "0.11" +libc = ">=0.2.39" + [lib] name = "krun" crate-type = ["cdylib", "lib"] diff --git a/src/libkrun/src/lib.rs b/src/libkrun/src/lib.rs index 8acf6d205..83bb6dee9 100644 --- a/src/libkrun/src/lib.rs +++ b/src/libkrun/src/lib.rs @@ -1,477 +1,499 @@ #[macro_use] extern crate log; -use crossbeam_channel::unbounded; -#[cfg(feature = "blk")] -use devices::virtio::block::{ImageType, SyncMode}; -#[cfg(feature = "gpu")] -use devices::virtio::gpu::display::DisplayInfo; -#[cfg(feature = "net")] -use devices::virtio::net::device::VirtioNetBackend; -#[cfg(feature = "blk")] -use devices::virtio::CacheType; -use env_logger::{Env, Target}; -#[cfg(feature = "gpu")] -use krun_display::DisplayBackend; - -use libc::{c_char, c_int, size_t}; -use once_cell::sync::Lazy; -use polly::event_manager::EventManager; -#[cfg(all(feature = "blk", not(feature = "tee")))] -use rand::distr::{Alphanumeric, SampleString}; -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::convert::TryInto; -use std::env; -#[cfg(target_os = "linux")] -use std::ffi::CString; -use std::ffi::{c_void, CStr}; -use std::fs::File; -use std::io::IsTerminal; -#[cfg(target_os = "linux")] -use std::os::fd::AsRawFd; -use std::os::fd::{BorrowedFd, FromRawFd, RawFd}; -use std::path::PathBuf; -use std::slice; -use std::sync::atomic::{AtomicI32, Ordering}; -use std::sync::LazyLock; -use std::sync::Mutex; -use utils::eventfd::EventFd; -use vmm::resources::{ - DefaultVirtioConsoleConfig, PortConfig, SerialConsoleConfig, TsiFlags, VirtioConsoleConfigMode, - VmResources, VsockConfig, -}; -#[cfg(feature = "blk")] -use 
vmm::vmm_config::block::{BlockDeviceConfig, BlockRootConfig}; -#[cfg(not(feature = "tee"))] -use vmm::vmm_config::external_kernel::{ExternalKernel, KernelFormat}; -#[cfg(not(feature = "tee"))] -use vmm::vmm_config::firmware::FirmwareConfig; -#[cfg(not(feature = "tee"))] -use vmm::vmm_config::fs::FsDeviceConfig; -use vmm::vmm_config::kernel_bundle::KernelBundle; -#[cfg(feature = "tee")] -use vmm::vmm_config::kernel_bundle::{InitrdBundle, QbootBundle}; -use vmm::vmm_config::kernel_cmdline::{KernelCmdlineConfig, DEFAULT_KERNEL_CMDLINE}; -use vmm::vmm_config::machine_config::VmConfig; -#[cfg(feature = "net")] -use vmm::vmm_config::net::NetworkInterfaceConfig; -use vmm::vmm_config::vsock::VsockDeviceConfig; - -#[cfg(feature = "aws-nitro")] -use aws_nitro::enclave::NitroEnclave; - -#[cfg(feature = "gpu")] -use devices::virtio::display::{DisplayInfoEdid, PhysicalSize, MAX_DISPLAYS}; -#[cfg(feature = "input")] -use krun_input::{InputConfigBackend, InputEventProviderBackend}; - -// Value returned on success. We use libc's errors otherwise. -const KRUN_SUCCESS: i32 = 0; -// Maximum number of arguments/environment variables we allow -const MAX_ARGS: usize = 4096; - -// krunfw library name for each context -#[cfg(all(target_os = "linux", not(feature = "tee")))] -const KRUNFW_NAME: &str = "libkrunfw.so.5"; -#[cfg(all(target_os = "linux", feature = "amd-sev"))] -const KRUNFW_NAME: &str = "libkrunfw-sev.so.5"; -#[cfg(all(target_os = "linux", feature = "tdx"))] -const KRUNFW_NAME: &str = "libkrunfw-tdx.so.5"; -#[cfg(target_os = "macos")] -const KRUNFW_NAME: &str = "libkrunfw.5.dylib"; - -#[cfg(feature = "aws-nitro")] -static KRUN_NITRO_DEBUG: Mutex = Mutex::new(false); - -// Path to the init binary to be executed inside the VM. 
-const INIT_PATH: &str = "/init.krun"; - -static KRUNFW: LazyLock> = - LazyLock::new(|| unsafe { libloading::Library::new(KRUNFW_NAME).ok() }); - -pub struct KrunfwBindings { - get_kernel: libloading::Symbol< - 'static, - unsafe extern "C" fn(*mut u64, *mut u64, *mut size_t) -> *mut c_char, - >, - #[cfg(feature = "tee")] - get_initrd: libloading::Symbol<'static, unsafe extern "C" fn(*mut size_t) -> *mut c_char>, - #[cfg(feature = "tee")] - get_qboot: libloading::Symbol<'static, unsafe extern "C" fn(*mut size_t) -> *mut c_char>, -} - -impl KrunfwBindings { - fn load_bindings() -> Result { - let krunfw = match KRUNFW.as_ref() { - Some(krunfw) => krunfw, - None => return Err(libloading::Error::DlOpenUnknown), - }; - Ok(unsafe { - KrunfwBindings { - get_kernel: krunfw.get(b"krunfw_get_kernel")?, - #[cfg(feature = "tee")] - get_initrd: krunfw.get(b"krunfw_get_initrd")?, - #[cfg(feature = "tee")] - get_qboot: krunfw.get(b"krunfw_get_qboot")?, - } - }) - } - - pub fn new() -> Option { - Self::load_bindings().ok() - } -} - -#[derive(Clone)] -#[cfg(feature = "net")] -enum LegacyNetworkConfig { - VirtioNetPasst(RawFd), - VirtioNetGvproxy(PathBuf), -} - -#[derive(Default)] -struct ContextConfig { - krunfw: Option, - vmr: VmResources, - workdir: Option, - exec_path: Option, - env: Option, - args: Option, - rlimits: Option, - #[cfg(feature = "net")] - legacy_net_cfg: Option, - #[cfg(feature = "net")] - legacy_mac: Option<[u8; 6]>, - net_index: u8, - tsi_port_map: Option>, - vsock_config: VsockConfig, - #[cfg(feature = "blk")] - block_cfgs: Vec, +// On Windows, the entire C API is implemented in windows_api.rs, +// delegating to vmm::windows::* instead of the Unix VMM infrastructure. +#[cfg(target_os = "windows")] +mod windows_api; + +// ── Unix C API implementation ──────────────────────────────────────────────── +// On Windows, the entire C API is implemented in windows_api.rs, which +// delegates to vmm::windows::* directly. 
Everything below (imports, types, +// statics, and all #[no_mangle] krun_* functions) is the upstream Unix +// C API implementation — gated out on Windows by this single module cfg. +#[cfg(not(target_os = "windows"))] +mod unix_api { + + use crossbeam_channel::unbounded; #[cfg(feature = "blk")] - root_block_cfg: Option, + use devices::virtio::block::{ImageType, SyncMode}; + #[cfg(feature = "gpu")] + use devices::virtio::gpu::display::DisplayInfo; + #[cfg(feature = "net")] + use devices::virtio::net::device::VirtioNetBackend; #[cfg(feature = "blk")] - data_block_cfg: Option, + use devices::virtio::CacheType; + use env_logger::{Env, Target}; + #[cfg(feature = "gpu")] + use krun_display::DisplayBackend; + + use libc::{c_char, c_int, size_t}; + use once_cell::sync::Lazy; + use polly::event_manager::EventManager; + #[cfg(all(feature = "blk", not(feature = "tee")))] + use rand::distr::{Alphanumeric, SampleString}; + use std::collections::hash_map::Entry; + use std::collections::HashMap; + use std::convert::TryInto; + use std::env; + #[cfg(target_os = "linux")] + use std::ffi::CString; + use std::ffi::{c_void, CStr}; + use std::fs::File; + use std::io::IsTerminal; + #[cfg(target_os = "linux")] + use std::os::fd::AsRawFd; + use std::os::fd::{BorrowedFd, FromRawFd, RawFd}; + use std::path::PathBuf; + use std::slice; + use std::sync::atomic::{AtomicI32, Ordering}; + use std::sync::LazyLock; + use std::sync::Mutex; + use utils::eventfd::EventFd; + use vmm::resources::{ + DefaultVirtioConsoleConfig, PortConfig, SerialConsoleConfig, TsiFlags, + VirtioConsoleConfigMode, VmResources, VsockConfig, + }; #[cfg(feature = "blk")] - block_root: Option, + use vmm::vmm_config::block::{BlockDeviceConfig, BlockRootConfig}; + #[cfg(not(feature = "tee"))] + use vmm::vmm_config::external_kernel::{ExternalKernel, KernelFormat}; + #[cfg(not(feature = "tee"))] + use vmm::vmm_config::firmware::FirmwareConfig; + #[cfg(not(feature = "tee"))] + use vmm::vmm_config::fs::FsDeviceConfig; + use 
vmm::vmm_config::kernel_bundle::KernelBundle; #[cfg(feature = "tee")] - tee_config_file: Option, - unix_ipc_port_map: Option>, - shutdown_efd: Option, - gpu_virgl_flags: Option, - gpu_shm_size: Option, - enable_snd: bool, - console_output: Option, - vmm_uid: Option, - vmm_gid: Option, -} + use vmm::vmm_config::kernel_bundle::{InitrdBundle, QbootBundle}; + use vmm::vmm_config::kernel_cmdline::{KernelCmdlineConfig, DEFAULT_KERNEL_CMDLINE}; + use vmm::vmm_config::machine_config::VmConfig; + #[cfg(feature = "net")] + use vmm::vmm_config::net::NetworkInterfaceConfig; + use vmm::vmm_config::vsock::VsockDeviceConfig; -impl ContextConfig { - fn set_workdir(&mut self, workdir: String) { - self.workdir = Some(workdir); - } + #[cfg(feature = "aws-nitro")] + use aws_nitro::enclave::NitroEnclave; + + #[cfg(feature = "gpu")] + use devices::virtio::display::{DisplayInfoEdid, PhysicalSize, MAX_DISPLAYS}; + #[cfg(feature = "input")] + use krun_input::{InputConfigBackend, InputEventProviderBackend}; + + const KRUN_SUCCESS: i32 = 0; + const MAX_ARGS: usize = 4096; + + #[cfg(all(target_os = "linux", not(feature = "tee")))] + const KRUNFW_NAME: &str = "libkrunfw.so.5"; + #[cfg(all(target_os = "linux", feature = "amd-sev"))] + const KRUNFW_NAME: &str = "libkrunfw-sev.so.5"; + #[cfg(all(target_os = "linux", feature = "tdx"))] + const KRUNFW_NAME: &str = "libkrunfw-tdx.so.5"; + #[cfg(target_os = "macos")] + const KRUNFW_NAME: &str = "libkrunfw.5.dylib"; - fn get_workdir(&self) -> String { - match &self.workdir { - Some(workdir) => format!("KRUN_WORKDIR={workdir}"), - None => "".to_string(), - } - } + #[cfg(feature = "aws-nitro")] + static KRUN_NITRO_DEBUG: Mutex = Mutex::new(false); - fn set_exec_path(&mut self, exec_path: String) { - self.exec_path = Some(exec_path); + const INIT_PATH: &str = "/init.krun"; + + static KRUNFW: LazyLock> = + LazyLock::new(|| unsafe { libloading::Library::new(KRUNFW_NAME).ok() }); + + pub struct KrunfwBindings { + get_kernel: libloading::Symbol< + 'static, + 
unsafe extern "C" fn(*mut u64, *mut u64, *mut size_t) -> *mut c_char, + >, + #[cfg(feature = "tee")] + get_initrd: libloading::Symbol<'static, unsafe extern "C" fn(*mut size_t) -> *mut c_char>, + #[cfg(feature = "tee")] + get_qboot: libloading::Symbol<'static, unsafe extern "C" fn(*mut size_t) -> *mut c_char>, } - fn get_exec_path(&self) -> String { - match &self.exec_path { - Some(exec_path) => format!("KRUN_INIT={exec_path}"), - None => "".to_string(), + impl KrunfwBindings { + fn load_bindings() -> Result { + let krunfw = match KRUNFW.as_ref() { + Some(krunfw) => krunfw, + None => return Err(libloading::Error::DlOpenUnknown), + }; + Ok(unsafe { + KrunfwBindings { + get_kernel: krunfw.get(b"krunfw_get_kernel")?, + #[cfg(feature = "tee")] + get_initrd: krunfw.get(b"krunfw_get_initrd")?, + #[cfg(feature = "tee")] + get_qboot: krunfw.get(b"krunfw_get_qboot")?, + } + }) } - } - #[cfg(all(feature = "blk", not(feature = "tee")))] - fn set_block_root(&mut self, device: String, fstype: Option, options: Option) { - self.block_root = Some(BlockRootConfig { - device, - fstype, - options, - }); + pub fn new() -> Option { + Self::load_bindings().ok() + } } - fn get_block_root(&self) -> String { + #[derive(Clone)] + #[cfg(feature = "net")] + enum LegacyNetworkConfig { + VirtioNetPasst(RawFd), + VirtioNetGvproxy(PathBuf), + } + + #[derive(Default)] + struct ContextConfig { + krunfw: Option, + vmr: VmResources, + workdir: Option, + exec_path: Option, + env: Option, + args: Option, + rlimits: Option, + #[cfg(feature = "net")] + legacy_net_cfg: Option, + #[cfg(feature = "net")] + legacy_mac: Option<[u8; 6]>, + #[cfg(feature = "net")] + disable_tsi: bool, + net_index: u8, + tsi_port_map: Option>, + vsock_config: VsockConfig, #[cfg(feature = "blk")] - match &self.block_root { - Some(block_root) => { - let mut res = format!("KRUN_BLOCK_ROOT_DEVICE={}", block_root.device); - if let Some(fstype) = &block_root.fstype { - res += &format!(" KRUN_BLOCK_ROOT_FSTYPE={fstype}"); - } - if let 
Some(options) = &block_root.options { - res += &format!(" KRUN_BLOCK_ROOT_OPTIONS={options}"); - } - res + block_cfgs: Vec, + #[cfg(feature = "blk")] + root_block_cfg: Option, + #[cfg(feature = "blk")] + data_block_cfg: Option, + #[cfg(feature = "blk")] + block_root: Option, + #[cfg(feature = "tee")] + tee_config_file: Option, + unix_ipc_port_map: Option>, + shutdown_efd: Option, + gpu_virgl_flags: Option, + gpu_shm_size: Option, + enable_snd: bool, + console_output: Option, + vmm_uid: Option, + vmm_gid: Option, + } + + impl ContextConfig { + fn set_workdir(&mut self, workdir: String) { + self.workdir = Some(workdir); + } + + fn get_workdir(&self) -> String { + match &self.workdir { + Some(workdir) => format!("KRUN_WORKDIR={workdir}"), + None => "".to_string(), } - None => "".to_string(), } - #[cfg(not(feature = "blk"))] - "".to_string() - } - fn set_env(&mut self, env: String) { - self.env = Some(env); - } + fn set_exec_path(&mut self, exec_path: String) { + self.exec_path = Some(exec_path); + } - fn get_env(&self) -> String { - match &self.env { - Some(env) => env.clone(), - None => "".to_string(), + fn get_exec_path(&self) -> String { + match &self.exec_path { + Some(exec_path) => format!("KRUN_INIT={exec_path}"), + None => "".to_string(), + } } - } - fn set_args(&mut self, args: String) { - self.args = Some(args); - } + #[cfg(all(feature = "blk", not(feature = "tee")))] + fn set_block_root( + &mut self, + device: String, + fstype: Option, + options: Option, + ) { + self.block_root = Some(BlockRootConfig { + device, + fstype, + options, + }); + } - fn get_args(&self) -> String { - match &self.args { - Some(args) => args.clone(), - None => "".to_string(), + fn get_block_root(&self) -> String { + #[cfg(feature = "blk")] + match &self.block_root { + Some(block_root) => { + let mut res = format!("KRUN_BLOCK_ROOT_DEVICE={}", block_root.device); + if let Some(fstype) = &block_root.fstype { + res += &format!(" KRUN_BLOCK_ROOT_FSTYPE={fstype}"); + } + if let 
Some(options) = &block_root.options { + res += &format!(" KRUN_BLOCK_ROOT_OPTIONS={options}"); + } + res + } + None => "".to_string(), + } + #[cfg(not(feature = "blk"))] + "".to_string() } - } - fn set_rlimits(&mut self, rlimits: String) { - self.rlimits = Some(rlimits); - } + fn set_env(&mut self, env: String) { + self.env = Some(env); + } - fn get_rlimits(&self) -> String { - match &self.rlimits { - Some(rlimits) => format!("KRUN_RLIMITS={rlimits}"), - None => "".to_string(), + fn get_env(&self) -> String { + match &self.env { + Some(env) => env.clone(), + None => "".to_string(), + } } - } - #[cfg(feature = "blk")] - fn add_block_cfg(&mut self, block_cfg: BlockDeviceConfig) { - self.block_cfgs.push(block_cfg); - } + fn set_args(&mut self, args: String) { + self.args = Some(args); + } - #[cfg(feature = "blk")] - fn set_root_block_cfg(&mut self, block_cfg: BlockDeviceConfig) { - self.root_block_cfg = Some(block_cfg); - } + fn get_args(&self) -> String { + match &self.args { + Some(args) => args.clone(), + None => "".to_string(), + } + } - #[cfg(feature = "blk")] - fn set_data_block_cfg(&mut self, block_cfg: BlockDeviceConfig) { - self.data_block_cfg = Some(block_cfg); - } + fn set_rlimits(&mut self, rlimits: String) { + self.rlimits = Some(rlimits); + } - #[cfg(feature = "blk")] - fn get_block_cfg(&self) -> Vec { - // For backwards compat, when cfgs is empty (the new API is not used), this needs to be - // root and then data, in that order. Also for backwards compat, root/data are setters and - // need to discard redundant calls. So we have simple setters above and fix up here. - // - // When the new API is used, this is simpler. 
- if self.block_cfgs.is_empty() { - [&self.root_block_cfg, &self.data_block_cfg] - .into_iter() - .filter_map(|cfg| cfg.clone()) - .collect() - } else { - self.block_cfgs.clone() + fn get_rlimits(&self) -> String { + match &self.rlimits { + Some(rlimits) => format!("KRUN_RLIMITS={rlimits}"), + None => "".to_string(), + } } - } - #[cfg(feature = "net")] - fn set_net_mac(&mut self, mac: [u8; 6]) { - self.legacy_mac = Some(mac); - } + #[cfg(feature = "blk")] + fn add_block_cfg(&mut self, block_cfg: BlockDeviceConfig) { + self.block_cfgs.push(block_cfg); + } - fn set_port_map(&mut self, new_port_map: HashMap) -> Result<(), ()> { - if self.net_index != 0 { - return Err(()); + #[cfg(feature = "blk")] + fn set_root_block_cfg(&mut self, block_cfg: BlockDeviceConfig) { + self.root_block_cfg = Some(block_cfg); } - self.tsi_port_map.replace(new_port_map); - Ok(()) - } + #[cfg(feature = "blk")] + fn set_data_block_cfg(&mut self, block_cfg: BlockDeviceConfig) { + self.data_block_cfg = Some(block_cfg); + } - #[cfg(feature = "tee")] - fn set_tee_config_file(&mut self, filepath: PathBuf) { - self.tee_config_file = Some(filepath); - } + #[cfg(feature = "blk")] + fn get_block_cfg(&self) -> Vec { + // For backwards compat, when cfgs is empty (the new API is not used), this needs to be + // root and then data, in that order. Also for backwards compat, root/data are setters and + // need to discard redundant calls. So we have simple setters above and fix up here. + // + // When the new API is used, this is simpler. 
+ if self.block_cfgs.is_empty() { + [&self.root_block_cfg, &self.data_block_cfg] + .into_iter() + .filter_map(|cfg| cfg.clone()) + .collect() + } else { + self.block_cfgs.clone() + } + } - #[cfg(feature = "tee")] - fn get_tee_config_file(&self) -> Option { - self.tee_config_file.clone() - } + #[cfg(feature = "net")] + fn set_net_mac(&mut self, mac: [u8; 6]) { + self.legacy_mac = Some(mac); + } - fn add_vsock_port(&mut self, port: u32, filepath: PathBuf, listen: bool) { - if let Some(ref mut map) = &mut self.unix_ipc_port_map { - map.insert(port, (filepath, listen)); - } else { - let mut map: HashMap = HashMap::new(); - map.insert(port, (filepath, listen)); - self.unix_ipc_port_map = Some(map); + fn set_port_map(&mut self, new_port_map: HashMap) -> Result<(), ()> { + if self.net_index != 0 { + return Err(()); + } + + self.tsi_port_map.replace(new_port_map); + Ok(()) } - } - fn set_gpu_virgl_flags(&mut self, virgl_flags: u32) { - self.gpu_virgl_flags = Some(virgl_flags); - } + #[cfg(feature = "tee")] + fn set_tee_config_file(&mut self, filepath: PathBuf) { + self.tee_config_file = Some(filepath); + } - fn set_gpu_shm_size(&mut self, shm_size: usize) { - self.gpu_shm_size = Some(shm_size); - } + #[cfg(feature = "tee")] + fn get_tee_config_file(&self) -> Option { + self.tee_config_file.clone() + } - fn set_vmm_uid(&mut self, vmm_uid: libc::uid_t) { - self.vmm_uid = Some(vmm_uid); - } + fn add_vsock_port(&mut self, port: u32, filepath: PathBuf, listen: bool) { + if let Some(ref mut map) = &mut self.unix_ipc_port_map { + map.insert(port, (filepath, listen)); + } else { + let mut map: HashMap = HashMap::new(); + map.insert(port, (filepath, listen)); + self.unix_ipc_port_map = Some(map); + } + } - fn set_vmm_gid(&mut self, vmm_gid: libc::gid_t) { - self.vmm_gid = Some(vmm_gid); - } -} + fn set_gpu_virgl_flags(&mut self, virgl_flags: u32) { + self.gpu_virgl_flags = Some(virgl_flags); + } -#[cfg(feature = "aws-nitro")] -impl TryFrom for NitroEnclave { - type Error = i32; + 
fn set_gpu_shm_size(&mut self, shm_size: usize) { + self.gpu_shm_size = Some(shm_size); + } - fn try_from(ctx: ContextConfig) -> Result { - let vm_config = ctx.vmr.vm_config(); + fn set_vmm_uid(&mut self, vmm_uid: libc::uid_t) { + self.vmm_uid = Some(vmm_uid); + } - let Some(mem_size_mib) = vm_config.mem_size_mib else { - error!("memory size not configured"); - return Err(-libc::EINVAL); - }; + fn set_vmm_gid(&mut self, vmm_gid: libc::gid_t) { + self.vmm_gid = Some(vmm_gid); + } + } - let Some(vcpus) = vm_config.vcpu_count else { - error!("vCPU count not configured"); - return Err(-libc::EINVAL); - }; + #[cfg(feature = "aws-nitro")] + impl TryFrom for NitroEnclave { + type Error = i32; - let rootfs = if let Some(path) = &ctx.vmr.fs.first() { - path.shared_dir.clone() - } else { - error!("rootfs path required"); - return Err(-libc::EINVAL); - }; + fn try_from(ctx: ContextConfig) -> Result { + let vm_config = ctx.vmr.vm_config(); - let Some(exec_path) = ctx.exec_path else { - error!("exec path not specified"); - return Err(-libc::EINVAL); - }; + let Some(mem_size_mib) = vm_config.mem_size_mib else { + error!("memory size not configured"); + return Err(-libc::EINVAL); + }; - let Some(exec_env) = ctx.env else { - error!("execution env not specified"); - return Err(-libc::EINVAL); - }; + let Some(vcpus) = vm_config.vcpu_count else { + error!("vCPU count not configured"); + return Err(-libc::EINVAL); + }; - let Some(exec_args) = ctx.args else { - error!("execution args not specified"); - return Err(-libc::EINVAL); - }; + let rootfs = if let Some(path) = &ctx.vmr.fs.first() { + path.shared_dir.clone() + } else { + error!("rootfs path required"); + return Err(-libc::EINVAL); + }; - let net_unixfd = { - let mut list = ctx.vmr.net.list; - let len = list.len(); - match len { - 0 => None, - 1 => { - let device = list.pop_front().unwrap(); - let device = device.lock().unwrap(); + let Some(exec_path) = ctx.exec_path else { + error!("exec path not specified"); + return 
Err(-libc::EINVAL); + }; - let fd = match device.cfg_backend { - VirtioNetBackend::UnixstreamFd(fd) => RawFd::from(fd), - _ => return Err(libc::EINVAL), - }; + let Some(exec_env) = ctx.env else { + error!("execution env not specified"); + return Err(-libc::EINVAL); + }; - Some(fd) - } - _ => { - error!( + let Some(exec_args) = ctx.args else { + error!("execution args not specified"); + return Err(-libc::EINVAL); + }; + + let net_unixfd = { + let mut list = ctx.vmr.net.list; + let len = list.len(); + match len { + 0 => None, + 1 => { + let device = list.pop_front().unwrap(); + let device = device.lock().unwrap(); + + let fd = match device.cfg_backend { + VirtioNetBackend::UnixstreamFd(fd) => RawFd::from(fd), + _ => return Err(libc::EINVAL), + }; + + Some(fd) + } + _ => { + error!( "more than one network interface configured (max 1 allowed, found {len})" ); - return Err(-libc::EINVAL); + return Err(-libc::EINVAL); + } } - } - }; - - let Some(output_path) = ctx.console_output else { - error!("console output path not specified"); - return Err(-libc::EINVAL); - }; + }; - let debug = KRUN_NITRO_DEBUG.lock().unwrap(); + let Some(output_path) = ctx.console_output else { + error!("console output path not specified"); + return Err(-libc::EINVAL); + }; - Ok(Self { - mem_size_mib, - vcpus, - rootfs, - exec_path, - exec_args, - exec_env, - net_unixfd, - output_path, - debug: *debug, - }) + let debug = KRUN_NITRO_DEBUG.lock().unwrap(); + + Ok(Self { + mem_size_mib, + vcpus, + rootfs, + exec_path, + exec_args, + exec_env, + net_unixfd, + output_path, + debug: *debug, + }) + } } -} -// TODO: Use this everywhere instead of the manual match -#[allow(dead_code)] -fn with_cfg(ctx_id: u32, f: impl FnOnce(&mut ContextConfig) -> i32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => f(ctx_cfg.get_mut()), - Entry::Vacant(_) => -libc::ENOENT, + // TODO: Use this everywhere instead of the manual match + #[allow(dead_code)] + fn with_cfg(ctx_id: u32, 
f: impl FnOnce(&mut ContextConfig) -> i32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => f(ctx_cfg.get_mut()), + Entry::Vacant(_) => -libc::ENOENT, + } } -} -static CTX_MAP: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); -static CTX_IDS: AtomicI32 = AtomicI32::new(0); + static CTX_MAP: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + static CTX_IDS: AtomicI32 = AtomicI32::new(0); -fn log_level_to_filter_str(level: u32) -> &'static str { - match level { - 0 => "off", - 1 => "error", - 2 => "warn", - 3 => "info", - 4 => "debug", - _ => "trace", + fn log_level_to_filter_str(level: u32) -> &'static str { + match level { + 0 => "off", + 1 => "error", + 2 => "warn", + 3 => "info", + 4 => "debug", + _ => "trace", + } } -} -#[no_mangle] -pub extern "C" fn krun_set_log_level(level: u32) -> i32 { - let filter = log_level_to_filter_str(level); - env_logger::Builder::from_env(Env::default().default_filter_or(filter)) - .format_timestamp_micros() - .init(); + #[no_mangle] + pub extern "C" fn krun_set_log_level(level: u32) -> i32 { + let filter = log_level_to_filter_str(level); + env_logger::Builder::from_env(Env::default().default_filter_or(filter)) + .format_timestamp_micros() + .init(); - #[cfg(feature = "aws-nitro")] - { - // Notify krun-awsnitro to enable debug for log level. - if level == 4 { - let mut debug = KRUN_NITRO_DEBUG.lock().unwrap(); + #[cfg(feature = "aws-nitro")] + { + // Notify krun-awsnitro to enable debug for log level. 
+ if level == 4 { + let mut debug = KRUN_NITRO_DEBUG.lock().unwrap(); - *debug = true; + *debug = true; + } } - } - KRUN_SUCCESS -} + KRUN_SUCCESS + } -mod log_defs { - pub const KRUN_LOG_STYLE_AUTO: u32 = 0; - pub const KRUN_LOG_STYLE_ALWAYS: u32 = 1; - pub const KRUN_LOG_STYLE_NEVER: u32 = 2; - pub const KRUN_LOG_OPTION_NO_ENV: u32 = 1; -} + mod log_defs { + pub const KRUN_LOG_STYLE_AUTO: u32 = 0; + pub const KRUN_LOG_STYLE_ALWAYS: u32 = 1; + pub const KRUN_LOG_STYLE_NEVER: u32 = 2; + pub const KRUN_LOG_OPTION_NO_ENV: u32 = 1; + } -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_init_log(target: RawFd, level: u32, style: u32, options: u32) -> i32 { - let target = match target { + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_init_log( + target: RawFd, + level: u32, + style: u32, + options: u32, + ) -> i32 { + let target = match target { ..-1 => return -libc::EINVAL, -1 => Target::default(), 0 /* stdin */ => return -libc::EINVAL, @@ -480,787 +502,748 @@ pub unsafe extern "C" fn krun_init_log(target: RawFd, level: u32, style: u32, op fd => Target::Pipe(Box::new(File::from_raw_fd(fd))), }; - let filter = log_level_to_filter_str(level); - - let write_style = match style { - log_defs::KRUN_LOG_STYLE_AUTO => "auto", - log_defs::KRUN_LOG_STYLE_ALWAYS => "always", - log_defs::KRUN_LOG_STYLE_NEVER => "never", - _ => return -libc::EINVAL, - }; - - let use_env = match options { - 0 => true, - log_defs::KRUN_LOG_OPTION_NO_ENV => false, - _ => return -libc::EINVAL, - }; - - let mut builder = if use_env { - env_logger::Builder::from_env( - Env::new() - .default_filter_or(filter) - .default_write_style_or(write_style), - ) - } else { - let mut builder = env_logger::Builder::new(); - builder.parse_filters(filter).parse_write_style(write_style); - builder - }; - builder.format_timestamp_micros().target(target).init(); + let filter = log_level_to_filter_str(level); - KRUN_SUCCESS -} + let write_style = match 
style { + log_defs::KRUN_LOG_STYLE_AUTO => "auto", + log_defs::KRUN_LOG_STYLE_ALWAYS => "always", + log_defs::KRUN_LOG_STYLE_NEVER => "never", + _ => return -libc::EINVAL, + }; -#[no_mangle] -pub extern "C" fn krun_create_ctx() -> i32 { - let shutdown_efd = if cfg!(target_arch = "aarch64") && cfg!(target_os = "macos") { - Some(EventFd::new(utils::eventfd::EFD_NONBLOCK).unwrap()) - } else { - None - }; + let use_env = match options { + 0 => true, + log_defs::KRUN_LOG_OPTION_NO_ENV => false, + _ => return -libc::EINVAL, + }; - let ctx_cfg = { - ContextConfig { - krunfw: KrunfwBindings::new(), - shutdown_efd, - ..Default::default() - } - }; + let mut builder = if use_env { + env_logger::Builder::from_env( + Env::new() + .default_filter_or(filter) + .default_write_style_or(write_style), + ) + } else { + let mut builder = env_logger::Builder::new(); + builder.parse_filters(filter).parse_write_style(write_style); + builder + }; + builder.format_timestamp_micros().target(target).init(); - let ctx_id = CTX_IDS.fetch_add(1, Ordering::SeqCst); - if ctx_id == i32::MAX || CTX_MAP.lock().unwrap().contains_key(&(ctx_id as u32)) { - // libkrun is not intended to be used as a daemon for managing VMs. 
- panic!("Context ID namespace exhausted"); + KRUN_SUCCESS } - CTX_MAP.lock().unwrap().insert(ctx_id as u32, ctx_cfg); - ctx_id -} + #[no_mangle] + pub extern "C" fn krun_create_ctx() -> i32 { + let shutdown_efd = if cfg!(target_arch = "aarch64") && cfg!(target_os = "macos") { + Some(EventFd::new(utils::eventfd::EFD_NONBLOCK).unwrap()) + } else { + None + }; -#[no_mangle] -pub extern "C" fn krun_free_ctx(ctx_id: u32) -> i32 { - match CTX_MAP.lock().unwrap().remove(&ctx_id) { - Some(_) => KRUN_SUCCESS, - None => -libc::ENOENT, - } -} + let ctx_cfg = { + ContextConfig { + krunfw: KrunfwBindings::new(), + shutdown_efd, + ..Default::default() + } + }; -#[no_mangle] -pub extern "C" fn krun_set_vm_config(ctx_id: u32, num_vcpus: u8, ram_mib: u32) -> i32 { - let mem_size_mib: usize = match ram_mib.try_into() { - Ok(size) => size, - Err(e) => { - warn!("Error parsing the amount of RAM: {e:?}"); - return -libc::EINVAL; + let ctx_id = CTX_IDS.fetch_add(1, Ordering::SeqCst); + if ctx_id == i32::MAX || CTX_MAP.lock().unwrap().contains_key(&(ctx_id as u32)) { + // libkrun is not intended to be used as a daemon for managing VMs. 
+ panic!("Context ID namespace exhausted"); } - }; + CTX_MAP.lock().unwrap().insert(ctx_id as u32, ctx_cfg); - let vm_config = VmConfig { - vcpu_count: Some(num_vcpus), - mem_size_mib: Some(mem_size_mib), - ht_enabled: Some(false), - cpu_template: None, - }; + ctx_id + } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - if ctx_cfg.get_mut().vmr.set_vm_config(&vm_config).is_err() { - return -libc::EINVAL; - } + #[no_mangle] + pub extern "C" fn krun_free_ctx(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().remove(&ctx_id) { + Some(_) => KRUN_SUCCESS, + None => -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, } - KRUN_SUCCESS -} + #[no_mangle] + pub extern "C" fn krun_set_vm_config(ctx_id: u32, num_vcpus: u8, ram_mib: u32) -> i32 { + let mem_size_mib: usize = match ram_mib.try_into() { + Ok(size) => size, + Err(e) => { + warn!("Error parsing the amount of RAM: {e:?}"); + return -libc::EINVAL; + } + }; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(not(feature = "tee"))] -pub unsafe extern "C" fn krun_set_root(ctx_id: u32, c_root_path: *const c_char) -> i32 { - let root_path = match CStr::from_ptr(c_root_path).to_str() { - Ok(root) => root, - Err(_) => return -libc::EINVAL, - }; + let vm_config = VmConfig { + vcpu_count: Some(num_vcpus), + mem_size_mib: Some(mem_size_mib), + ht_enabled: Some(false), + cpu_template: None, + }; - let fs_id = "/dev/root".to_string(); - let shared_dir = root_path.to_string(); - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.vmr.add_fs_device(FsDeviceConfig { - fs_id, - shared_dir, - // Default to a conservative 512 MB window. 
- shm_size: Some(1 << 29), - allow_root_dir_delete: false, - read_only: false, - }); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + if ctx_cfg.get_mut().vmr.set_vm_config(&vm_config).is_err() { + return -libc::EINVAL; + } + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(not(feature = "tee"))] -pub unsafe extern "C" fn krun_add_virtiofs( - ctx_id: u32, - c_tag: *const c_char, - c_path: *const c_char, -) -> i32 { - krun_add_virtiofs3(ctx_id, c_tag, c_path, 0, false) -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(not(feature = "tee"))] -pub unsafe extern "C" fn krun_add_virtiofs2( - ctx_id: u32, - c_tag: *const c_char, - c_path: *const c_char, - shm_size: u64, -) -> i32 { - krun_add_virtiofs3(ctx_id, c_tag, c_path, shm_size, false) -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(not(feature = "tee"))] -pub unsafe extern "C" fn krun_add_virtiofs3( - ctx_id: u32, - c_tag: *const c_char, - c_path: *const c_char, - shm_size: u64, - read_only: bool, -) -> i32 { - if c_tag.is_null() || c_path.is_null() { - return -libc::EINVAL; - } - - let tag = match CStr::from_ptr(c_tag).to_str() { - Ok(tag) => tag, - Err(_) => return -libc::EINVAL, - }; - let path = match CStr::from_ptr(c_path).to_str() { - Ok(path) => path, - Err(_) => return -libc::EINVAL, - }; - let shm = if shm_size > 0 { - match shm_size.try_into() { - Ok(s) => Some(s), + KRUN_SUCCESS + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(not(feature = "tee"))] + pub unsafe extern "C" fn krun_set_root(ctx_id: u32, c_root_path: *const c_char) -> i32 { + let root_path = match CStr::from_ptr(c_root_path).to_str() { + Ok(root) => root, Err(_) => return -libc::EINVAL, - } - } else { - None - }; + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = 
ctx_cfg.get_mut(); - cfg.vmr.add_fs_device(FsDeviceConfig { - fs_id: tag.to_string(), - shared_dir: path.to_string(), - shm_size: shm, - allow_root_dir_delete: false, - read_only, - }); + let fs_id = "/dev/root".to_string(); + let shared_dir = root_path.to_string(); + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.add_fs_device(FsDeviceConfig { + fs_id, + shared_dir, + // Default to a conservative 512 MB window. + shm_size: Some(1 << 29), + allow_root_dir_delete: false, + read_only: false, + }); + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(not(feature = "tee"))] -pub unsafe extern "C" fn krun_set_mapped_volumes( - _ctx_id: u32, - _c_mapped_volumes: *const *const c_char, -) -> i32 { - -libc::EINVAL -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "blk")] -pub unsafe extern "C" fn krun_add_disk( - ctx_id: u32, - c_block_id: *const c_char, - c_disk_path: *const c_char, - read_only: bool, -) -> i32 { - let disk_path = match CStr::from_ptr(c_disk_path).to_str() { - Ok(disk) => disk, - Err(_) => return -libc::EINVAL, - }; - - let block_id = match CStr::from_ptr(c_block_id).to_str() { - Ok(block_id) => block_id, - Err(_) => return -libc::EINVAL, - }; - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - let block_device_config = BlockDeviceConfig { - block_id: block_id.to_string(), - cache_type: CacheType::auto(disk_path), - disk_image_path: disk_path.to_string(), - disk_image_format: ImageType::Raw, - is_disk_read_only: read_only, - direct_io: false, - #[cfg(not(target_os = "macos"))] - sync_mode: SyncMode::Full, - #[cfg(target_os = "macos")] - sync_mode: SyncMode::Relaxed, - }; - cfg.add_block_cfg(block_device_config); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - 
KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "blk")] -pub unsafe extern "C" fn krun_add_disk2( - ctx_id: u32, - c_block_id: *const c_char, - c_disk_path: *const c_char, - disk_format: u32, - read_only: bool, -) -> i32 { - let disk_path = match CStr::from_ptr(c_disk_path).to_str() { - Ok(disk) => disk, - Err(_) => return -libc::EINVAL, - }; - - let block_id = match CStr::from_ptr(c_block_id).to_str() { - Ok(block_id) => block_id, - Err(_) => return -libc::EINVAL, - }; - let format = match ImageType::try_from(disk_format) { - Ok(format) => format, - Err(_) => return -libc::EINVAL, - }; - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - let block_device_config = BlockDeviceConfig { - block_id: block_id.to_string(), - cache_type: CacheType::auto(disk_path), - disk_image_path: disk_path.to_string(), - disk_image_format: format, - is_disk_read_only: read_only, - direct_io: false, - #[cfg(not(target_os = "macos"))] - sync_mode: SyncMode::Full, - #[cfg(target_os = "macos")] - sync_mode: SyncMode::Relaxed, - }; - cfg.add_block_cfg(block_device_config); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "blk")] -pub unsafe extern "C" fn krun_add_disk3( - ctx_id: u32, - c_block_id: *const c_char, - c_disk_path: *const c_char, - disk_format: u32, - read_only: bool, - direct_io: bool, - sync_mode: u32, -) -> i32 { - let disk_path = match CStr::from_ptr(c_disk_path).to_str() { - Ok(disk) => disk, - Err(_) => return -libc::EINVAL, - }; + KRUN_SUCCESS + } - let block_id = match CStr::from_ptr(c_block_id).to_str() { - Ok(block_id) => block_id, - Err(_) => return -libc::EINVAL, - }; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(not(feature = "tee"))] + pub unsafe extern "C" fn krun_add_virtiofs( + ctx_id: u32, + c_tag: *const c_char, + c_path: *const c_char, + ) -> i32 
{ + krun_add_virtiofs3(ctx_id, c_tag, c_path, 0, false) + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(not(feature = "tee"))] + pub unsafe extern "C" fn krun_add_virtiofs2( + ctx_id: u32, + c_tag: *const c_char, + c_path: *const c_char, + shm_size: u64, + ) -> i32 { + krun_add_virtiofs3(ctx_id, c_tag, c_path, shm_size, false) + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(not(feature = "tee"))] + pub unsafe extern "C" fn krun_add_virtiofs3( + ctx_id: u32, + c_tag: *const c_char, + c_path: *const c_char, + shm_size: u64, + read_only: bool, + ) -> i32 { + if c_tag.is_null() || c_path.is_null() { + return -libc::EINVAL; + } - let format = match ImageType::try_from(disk_format) { - Ok(fmt) => fmt, - Err(_) => return -libc::EINVAL, - }; + let tag = match CStr::from_ptr(c_tag).to_str() { + Ok(tag) => tag, + Err(_) => return -libc::EINVAL, + }; + let path = match CStr::from_ptr(c_path).to_str() { + Ok(path) => path, + Err(_) => return -libc::EINVAL, + }; - let sync_mode = match SyncMode::try_from(sync_mode) { - Ok(mode) => mode, - Err(_) => return -libc::EINVAL, - }; + let shm = if shm_size > 0 { + match shm_size.try_into() { + Ok(s) => Some(s), + Err(_) => return -libc::EINVAL, + } + } else { + None + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - let block_device_config = BlockDeviceConfig { - block_id: block_id.to_string(), - cache_type: CacheType::auto(disk_path), - disk_image_path: disk_path.to_string(), - disk_image_format: format, - is_disk_read_only: read_only, - direct_io, - sync_mode, - }; - cfg.add_block_cfg(block_device_config); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.add_fs_device(FsDeviceConfig { + fs_id: tag.to_string(), + shared_dir: path.to_string(), + shm_size: shm, + allow_root_dir_delete: false, + read_only, + }); + } + Entry::Vacant(_) => return 
-libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "blk")] -pub unsafe extern "C" fn krun_set_root_disk(ctx_id: u32, c_disk_path: *const c_char) -> i32 { - let disk_path = match CStr::from_ptr(c_disk_path).to_str() { - Ok(disk) => disk, - Err(_) => return -libc::EINVAL, - }; + KRUN_SUCCESS + } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - let block_device_config = BlockDeviceConfig { - block_id: "root".to_string(), - cache_type: CacheType::auto(disk_path), - disk_image_path: disk_path.to_string(), - disk_image_format: ImageType::Raw, - is_disk_read_only: false, - direct_io: false, - #[cfg(not(target_os = "macos"))] - sync_mode: SyncMode::Full, - #[cfg(target_os = "macos")] - sync_mode: SyncMode::Relaxed, - }; - cfg.set_root_block_cfg(block_device_config); - } - Entry::Vacant(_) => return -libc::ENOENT, + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(not(feature = "tee"))] + pub unsafe extern "C" fn krun_set_mapped_volumes( + _ctx_id: u32, + _c_mapped_volumes: *const *const c_char, + ) -> i32 { + -libc::EINVAL } - KRUN_SUCCESS -} + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "blk")] + pub unsafe extern "C" fn krun_add_disk( + ctx_id: u32, + c_block_id: *const c_char, + c_disk_path: *const c_char, + read_only: bool, + ) -> i32 { + let disk_path = match CStr::from_ptr(c_disk_path).to_str() { + Ok(disk) => disk, + Err(_) => return -libc::EINVAL, + }; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "blk")] -pub unsafe extern "C" fn krun_set_data_disk(ctx_id: u32, c_disk_path: *const c_char) -> i32 { - let disk_path = match CStr::from_ptr(c_disk_path).to_str() { - Ok(disk) => disk, - Err(_) => return -libc::EINVAL, - }; + let block_id = match CStr::from_ptr(c_block_id).to_str() { + Ok(block_id) => block_id, + Err(_) => return -libc::EINVAL, 
+ }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - let block_device_config = BlockDeviceConfig { - block_id: "data".to_string(), - cache_type: CacheType::auto(disk_path), - disk_image_path: disk_path.to_string(), - disk_image_format: ImageType::Raw, - is_disk_read_only: false, - direct_io: false, - #[cfg(not(target_os = "macos"))] - sync_mode: SyncMode::Full, - #[cfg(target_os = "macos")] - sync_mode: SyncMode::Relaxed, - }; - cfg.set_data_block_cfg(block_device_config); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -/* - * Send the VFKIT magic after establishing the connection, - * as required by gvproxy in vfkit mode. - */ -#[cfg(feature = "net")] -const NET_FLAG_VFKIT: u32 = 1 << 0; -#[cfg(feature = "net")] -const NET_FLAG_DHCP_CLIENT: u32 = 1 << 1; -#[cfg(feature = "net")] -const NET_FLAG_ALL: u32 = NET_FLAG_VFKIT | NET_FLAG_DHCP_CLIENT; - -/* Taken from uapi/linux/virtio_net.h */ -#[cfg(feature = "net")] -const NET_FEATURE_CSUM: u32 = 1 << 0; -#[cfg(feature = "net")] -const NET_FEATURE_GUEST_CSUM: u32 = 1 << 1; -#[cfg(feature = "net")] -const NET_FEATURE_GUEST_TSO4: u32 = 1 << 7; -#[cfg(feature = "net")] -const NET_FEATURE_GUEST_TSO6: u32 = 1 << 8; -#[cfg(feature = "net")] -const NET_FEATURE_GUEST_UFO: u32 = 1 << 10; -#[cfg(feature = "net")] -const NET_FEATURE_HOST_TSO4: u32 = 1 << 11; -#[cfg(feature = "net")] -const NET_FEATURE_HOST_TSO6: u32 = 1 << 12; -#[cfg(feature = "net")] -const NET_FEATURE_HOST_UFO: u32 = 1 << 14; -/* - * These are the flags enabled by default on each virtio-net instance - * before the introduction of "krun_add_net_*". They are now used in - * the legacy API ("krun_set_passt_fd" and "krun_set_gvproxy_path") - * for compatiblity reasons. 
- */ -#[cfg(feature = "net")] -const NET_COMPAT_FEATURES: u32 = NET_FEATURE_CSUM - | NET_FEATURE_GUEST_CSUM - | NET_FEATURE_GUEST_TSO4 - | NET_FEATURE_GUEST_UFO - | NET_FEATURE_HOST_TSO4 - | NET_FEATURE_HOST_UFO; -#[cfg(feature = "net")] -const NET_ALL_FEATURES: u32 = NET_FEATURE_CSUM - | NET_FEATURE_GUEST_CSUM - | NET_FEATURE_GUEST_TSO4 - | NET_FEATURE_GUEST_TSO6 - | NET_FEATURE_GUEST_UFO - | NET_FEATURE_HOST_TSO4 - | NET_FEATURE_HOST_TSO6 - | NET_FEATURE_HOST_UFO; - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "net")] -pub unsafe extern "C" fn krun_add_net_unixstream( - ctx_id: u32, - c_path: *const c_char, - fd: c_int, - c_mac: *const u8, - features: u32, - flags: u32, -) -> i32 { - let path = if !c_path.is_null() { - match CStr::from_ptr(c_path).to_str() { - Ok(path) => Some(PathBuf::from(path)), - Err(_) => None, - } - } else { - None - }; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + let block_device_config = BlockDeviceConfig { + block_id: block_id.to_string(), + cache_type: CacheType::auto(disk_path), + disk_image_path: disk_path.to_string(), + disk_image_format: ImageType::Raw, + is_disk_read_only: read_only, + direct_io: false, + #[cfg(not(target_os = "macos"))] + sync_mode: SyncMode::Full, + #[cfg(target_os = "macos")] + sync_mode: SyncMode::Relaxed, + }; + cfg.add_block_cfg(block_device_config); + } + Entry::Vacant(_) => return -libc::ENOENT, + } - if fd >= 0 && path.is_some() { - return -libc::EINVAL; - } - if fd < 0 && path.is_none() { - return -libc::EINVAL; + KRUN_SUCCESS } - let backend = if let Some(path) = path { - VirtioNetBackend::UnixstreamPath(path) - } else { - VirtioNetBackend::UnixstreamFd(fd) - }; - let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { - Ok(m) => m, - Err(_) => return -libc::EINVAL, - }; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "blk")] + pub unsafe extern "C" fn krun_add_disk2( + 
ctx_id: u32, + c_block_id: *const c_char, + c_disk_path: *const c_char, + disk_format: u32, + read_only: bool, + ) -> i32 { + let disk_path = match CStr::from_ptr(c_disk_path).to_str() { + Ok(disk) => disk, + Err(_) => return -libc::EINVAL, + }; - if (flags & !NET_FLAG_DHCP_CLIENT) != 0 { - return -libc::EINVAL; - } - let enable_dhcp_client: bool = flags & NET_FLAG_DHCP_CLIENT != 0; + let block_id = match CStr::from_ptr(c_block_id).to_str() { + Ok(block_id) => block_id, + Err(_) => return -libc::EINVAL, + }; - if (features & !NET_ALL_FEATURES) != 0 { - return -libc::EINVAL; - } + let format = match ImageType::try_from(disk_format) { + Ok(format) => format, + Err(_) => return -libc::EINVAL, + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - create_virtio_net(cfg, backend, mac, features); - if enable_dhcp_client { - cfg.vmr.dhcp_client = true; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + let block_device_config = BlockDeviceConfig { + block_id: block_id.to_string(), + cache_type: CacheType::auto(disk_path), + disk_image_path: disk_path.to_string(), + disk_image_format: format, + is_disk_read_only: read_only, + direct_io: false, + #[cfg(not(target_os = "macos"))] + sync_mode: SyncMode::Full, + #[cfg(target_os = "macos")] + sync_mode: SyncMode::Relaxed, + }; + cfg.add_block_cfg(block_device_config); } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "net")] -pub unsafe extern "C" fn krun_add_net_unixgram( - ctx_id: u32, - c_path: *const c_char, - fd: c_int, - c_mac: *const u8, - features: u32, - flags: u32, -) -> i32 { - let path = if !c_path.is_null() { - match CStr::from_ptr(c_path).to_str() { - Ok(path) => Some(PathBuf::from(path)), - Err(_) => None, - } - } else { - None - }; - if fd >= 0 && 
path.is_some() { - return -libc::EINVAL; - } - if fd < 0 && path.is_none() { - return -libc::EINVAL; + KRUN_SUCCESS } - let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { - Ok(m) => m, - Err(_) => return -libc::EINVAL, - }; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "blk")] + pub unsafe extern "C" fn krun_add_disk3( + ctx_id: u32, + c_block_id: *const c_char, + c_disk_path: *const c_char, + disk_format: u32, + read_only: bool, + direct_io: bool, + sync_mode: u32, + ) -> i32 { + let disk_path = match CStr::from_ptr(c_disk_path).to_str() { + Ok(disk) => disk, + Err(_) => return -libc::EINVAL, + }; - if (features & !NET_ALL_FEATURES) != 0 { - return -libc::EINVAL; - } + let block_id = match CStr::from_ptr(c_block_id).to_str() { + Ok(block_id) => block_id, + Err(_) => return -libc::EINVAL, + }; - if (flags & !NET_FLAG_ALL) != 0 { - return -libc::EINVAL; - } - let send_vfkit_magic: bool = flags & NET_FLAG_VFKIT != 0; - let enable_dhcp_client: bool = flags & NET_FLAG_DHCP_CLIENT != 0; + let format = match ImageType::try_from(disk_format) { + Ok(fmt) => fmt, + Err(_) => return -libc::EINVAL, + }; - let backend = if let Some(path) = path { - VirtioNetBackend::UnixgramPath(path, send_vfkit_magic) - } else { - VirtioNetBackend::UnixgramFd(fd) - }; + let sync_mode = match SyncMode::try_from(sync_mode) { + Ok(mode) => mode, + Err(_) => return -libc::EINVAL, + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - create_virtio_net(cfg, backend, mac, features); - if enable_dhcp_client { - cfg.vmr.dhcp_client = true; - } - } - Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(all(target_os = "linux", feature = "net"))] -pub unsafe extern "C" fn krun_add_net_tap( - ctx_id: u32, - c_tap_name: *const c_char, - c_mac: *const u8, - features: u32, - flags: u32, -) -> i32 { - let tap_name = match 
CStr::from_ptr(c_tap_name).to_str() { - Ok(tap_name) => tap_name.to_string(), - Err(e) => { - debug!("Error parsing tap_name: {e:?}"); - return -libc::EINVAL; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + let block_device_config = BlockDeviceConfig { + block_id: block_id.to_string(), + cache_type: CacheType::auto(disk_path), + disk_image_path: disk_path.to_string(), + disk_image_format: format, + is_disk_read_only: read_only, + direct_io, + sync_mode, + }; + cfg.add_block_cfg(block_device_config); + } + Entry::Vacant(_) => return -libc::ENOENT, } - }; - - let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { - Ok(m) => m, - Err(_) => return -libc::EINVAL, - }; - if (features & !NET_ALL_FEATURES) != 0 { - return -libc::EINVAL; + KRUN_SUCCESS } - if features & (NET_FEATURE_GUEST_TSO4 | NET_FEATURE_GUEST_TSO6 | NET_FEATURE_GUEST_UFO) != 0 - && features & NET_FEATURE_GUEST_CSUM == 0 - { - debug!("Network tap backend requires GUEST_CSUM to be requested if any of GUEST_TSO4, GUEST_TSO6 and/or GUEST_UFO are required"); - return -libc::EINVAL; - } + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "blk")] + pub unsafe extern "C" fn krun_set_root_disk(ctx_id: u32, c_disk_path: *const c_char) -> i32 { + let disk_path = match CStr::from_ptr(c_disk_path).to_str() { + Ok(disk) => disk, + Err(_) => return -libc::EINVAL, + }; - if (flags & !NET_FLAG_DHCP_CLIENT) != 0 { - return -libc::EINVAL; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + let block_device_config = BlockDeviceConfig { + block_id: "root".to_string(), + cache_type: CacheType::auto(disk_path), + disk_image_path: disk_path.to_string(), + disk_image_format: ImageType::Raw, + is_disk_read_only: false, + direct_io: false, + #[cfg(not(target_os = "macos"))] + sync_mode: SyncMode::Full, + #[cfg(target_os = "macos")] + sync_mode: SyncMode::Relaxed, 
+ }; + cfg.set_root_block_cfg(block_device_config); + } + Entry::Vacant(_) => return -libc::ENOENT, + } + + KRUN_SUCCESS } - let enable_dhcp_client: bool = flags & NET_FLAG_DHCP_CLIENT != 0; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - create_virtio_net(cfg, VirtioNetBackend::Tap(tap_name), mac, features); - if enable_dhcp_client { - cfg.vmr.dhcp_client = true; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "blk")] + pub unsafe extern "C" fn krun_set_data_disk(ctx_id: u32, c_disk_path: *const c_char) -> i32 { + let disk_path = match CStr::from_ptr(c_disk_path).to_str() { + Ok(disk) => disk, + Err(_) => return -libc::EINVAL, + }; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + let block_device_config = BlockDeviceConfig { + block_id: "data".to_string(), + cache_type: CacheType::auto(disk_path), + disk_image_path: disk_path.to_string(), + disk_image_format: ImageType::Raw, + is_disk_read_only: false, + direct_io: false, + #[cfg(not(target_os = "macos"))] + sync_mode: SyncMode::Full, + #[cfg(target_os = "macos")] + sync_mode: SyncMode::Relaxed, + }; + cfg.set_data_block_cfg(block_device_config); } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + + KRUN_SUCCESS } - KRUN_SUCCESS -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(all(not(target_os = "linux"), feature = "net"))] -pub unsafe extern "C" fn krun_add_net_tap( - _ctx_id: u32, - _c_tap_name: *const c_char, - _c_mac: *const u8, - _features: u32, - _flags: u32, -) -> i32 { - -libc::EINVAL -} + /* + * Send the VFKIT magic after establishing the connection, + * as required by gvproxy in vfkit mode. 
+ */ + #[cfg(feature = "net")] + const NET_FLAG_VFKIT: u32 = 1 << 0; + #[cfg(feature = "net")] + const NET_FLAG_DHCP_CLIENT: u32 = 1 << 1; + #[cfg(feature = "net")] + const NET_FLAG_ALL: u32 = NET_FLAG_VFKIT | NET_FLAG_DHCP_CLIENT; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "net")] -pub unsafe extern "C" fn krun_set_passt_fd(ctx_id: u32, fd: c_int) -> i32 { - if fd < 0 { - return -libc::EINVAL; - } + /* Taken from uapi/linux/virtio_net.h */ + #[cfg(feature = "net")] + const NET_FEATURE_CSUM: u32 = 1 << 0; + #[cfg(feature = "net")] + const NET_FEATURE_GUEST_CSUM: u32 = 1 << 1; + #[cfg(feature = "net")] + const NET_FEATURE_GUEST_TSO4: u32 = 1 << 7; + #[cfg(feature = "net")] + const NET_FEATURE_GUEST_TSO6: u32 = 1 << 8; + #[cfg(feature = "net")] + const NET_FEATURE_GUEST_UFO: u32 = 1 << 10; + #[cfg(feature = "net")] + const NET_FEATURE_HOST_TSO4: u32 = 1 << 11; + #[cfg(feature = "net")] + const NET_FEATURE_HOST_TSO6: u32 = 1 << 12; + #[cfg(feature = "net")] + const NET_FEATURE_HOST_UFO: u32 = 1 << 14; + /* + * These are the flags enabled by default on each virtio-net instance + * before the introduction of "krun_add_net_*". They are now used in + * the legacy API ("krun_set_passt_fd" and "krun_set_gvproxy_path") + * for compatiblity reasons. 
+ */ + #[cfg(feature = "net")] + const NET_COMPAT_FEATURES: u32 = NET_FEATURE_CSUM + | NET_FEATURE_GUEST_CSUM + | NET_FEATURE_GUEST_TSO4 + | NET_FEATURE_GUEST_UFO + | NET_FEATURE_HOST_TSO4 + | NET_FEATURE_HOST_UFO; + #[cfg(feature = "net")] + const NET_ALL_FEATURES: u32 = NET_FEATURE_CSUM + | NET_FEATURE_GUEST_CSUM + | NET_FEATURE_GUEST_TSO4 + | NET_FEATURE_GUEST_TSO6 + | NET_FEATURE_GUEST_UFO + | NET_FEATURE_HOST_TSO4 + | NET_FEATURE_HOST_TSO6 + | NET_FEATURE_HOST_UFO; + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "net")] + pub unsafe extern "C" fn krun_add_net_unixstream( + ctx_id: u32, + c_path: *const c_char, + fd: c_int, + c_mac: *const u8, + features: u32, + flags: u32, + ) -> i32 { + let path = if !c_path.is_null() { + match CStr::from_ptr(c_path).to_str() { + Ok(path) => Some(PathBuf::from(path)), + Err(_) => None, + } + } else { + None + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - // The legacy interface only supports a single network interface. 
- if cfg.net_index != 0 { - return -libc::EINVAL; + if fd >= 0 && path.is_some() { + return -libc::EINVAL; + } + if fd < 0 && path.is_none() { + return -libc::EINVAL; + } + let backend = if let Some(path) = path { + VirtioNetBackend::UnixstreamPath(path) + } else { + VirtioNetBackend::UnixstreamFd(fd) + }; + + let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { + Ok(m) => m, + Err(_) => return -libc::EINVAL, + }; + + if (flags & !NET_FLAG_DHCP_CLIENT) != 0 { + return -libc::EINVAL; + } + let enable_dhcp_client: bool = flags & NET_FLAG_DHCP_CLIENT != 0; + + if (features & !NET_ALL_FEATURES) != 0 { + return -libc::EINVAL; + } + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + create_virtio_net(cfg, backend, mac, features); + if enable_dhcp_client { + cfg.vmr.dhcp_client = true; + } } - cfg.legacy_net_cfg = Some(LegacyNetworkConfig::VirtioNetPasst(fd)); + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + KRUN_SUCCESS } - KRUN_SUCCESS -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "net")] -pub unsafe extern "C" fn krun_set_gvproxy_path(ctx_id: u32, c_path: *const c_char) -> i32 { - let path_str = match CStr::from_ptr(c_path).to_str() { - Ok(path) => path, - Err(e) => { - debug!("Error parsing gvproxy_path: {e:?}"); + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "net")] + pub unsafe extern "C" fn krun_add_net_unixgram( + ctx_id: u32, + c_path: *const c_char, + fd: c_int, + c_mac: *const u8, + features: u32, + flags: u32, + ) -> i32 { + let path = if !c_path.is_null() { + match CStr::from_ptr(c_path).to_str() { + Ok(path) => Some(PathBuf::from(path)), + Err(_) => None, + } + } else { + None + }; + + if fd >= 0 && path.is_some() { + return -libc::EINVAL; + } + if fd < 0 && path.is_none() { return -libc::EINVAL; } - }; - let path = PathBuf::from(path_str); + let mac: [u8; 6] = match 
slice::from_raw_parts(c_mac, 6).try_into() { + Ok(m) => m, + Err(_) => return -libc::EINVAL, + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - // The legacy interface only supports a single network interface. - if cfg.net_index != 0 { - return -libc::EINVAL; + if (features & !NET_ALL_FEATURES) != 0 { + return -libc::EINVAL; + } + + if (flags & !NET_FLAG_ALL) != 0 { + return -libc::EINVAL; + } + let send_vfkit_magic: bool = flags & NET_FLAG_VFKIT != 0; + let enable_dhcp_client: bool = flags & NET_FLAG_DHCP_CLIENT != 0; + + let backend = if let Some(path) = path { + VirtioNetBackend::UnixgramPath(path, send_vfkit_magic) + } else { + VirtioNetBackend::UnixgramFd(fd) + }; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + create_virtio_net(cfg, backend, mac, features); + if enable_dhcp_client { + cfg.vmr.dhcp_client = true; + } } - cfg.legacy_net_cfg = Some(LegacyNetworkConfig::VirtioNetGvproxy(path)); + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + KRUN_SUCCESS } - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "net")] -pub unsafe extern "C" fn krun_set_net_mac(ctx_id: u32, c_mac: *const u8) -> i32 { - let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { - Ok(m) => m, - Err(_) => return -libc::EINVAL, - }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_net_mac(mac); + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "net")] + pub unsafe extern "C" fn krun_disable_tsi(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.disable_tsi = true; + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + KRUN_SUCCESS 
} - KRUN_SUCCESS -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_port_map(ctx_id: u32, c_port_map: *const *const c_char) -> i32 { - let mut port_map = HashMap::new(); - let port_map_array: &[*const c_char] = slice::from_raw_parts(c_port_map, MAX_ARGS); - for item in port_map_array.iter().take(MAX_ARGS) { - if item.is_null() { - break; - } else { - let s = match CStr::from_ptr(*item).to_str() { - Ok(s) => s, - Err(_) => return -libc::EINVAL, - }; - let port_tuple: Vec<&str> = s.split(':').collect(); - if port_tuple.len() != 2 { + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(all(target_os = "linux", feature = "net"))] + pub unsafe extern "C" fn krun_add_net_tap( + ctx_id: u32, + c_tap_name: *const c_char, + c_mac: *const u8, + features: u32, + flags: u32, + ) -> i32 { + let tap_name = match CStr::from_ptr(c_tap_name).to_str() { + Ok(tap_name) => tap_name.to_string(), + Err(e) => { + debug!("Error parsing tap_name: {e:?}"); return -libc::EINVAL; } - let host_port: u16 = match port_tuple[0].parse() { - Ok(p) => p, - Err(_) => return -libc::EINVAL, - }; - let guest_port: u16 = match port_tuple[1].parse() { - Ok(p) => p, - Err(_) => return -libc::EINVAL, - }; + }; - if port_map.contains_key(&guest_port) { - return -libc::EINVAL; + let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { + Ok(m) => m, + Err(_) => return -libc::EINVAL, + }; + + if (features & !NET_ALL_FEATURES) != 0 { + return -libc::EINVAL; + } + + if features & (NET_FEATURE_GUEST_TSO4 | NET_FEATURE_GUEST_TSO6 | NET_FEATURE_GUEST_UFO) != 0 + && features & NET_FEATURE_GUEST_CSUM == 0 + { + debug!( + "Network tap backend requires GUEST_CSUM to be requested if any of GUEST_TSO4, GUEST_TSO6 and/or GUEST_UFO are required" + ); + return -libc::EINVAL; + } + + if (flags & !NET_FLAG_DHCP_CLIENT) != 0 { + return -libc::EINVAL; + } + let enable_dhcp_client: bool = flags & NET_FLAG_DHCP_CLIENT != 0; + + match CTX_MAP.lock().unwrap().entry(ctx_id) 
{ + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + create_virtio_net(cfg, VirtioNetBackend::Tap(tap_name), mac, features); + if enable_dhcp_client { + cfg.vmr.dhcp_client = true; + } } - for hp in port_map.values() { - if *hp == host_port { + Entry::Vacant(_) => return -libc::ENOENT, + } + KRUN_SUCCESS + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(all(not(target_os = "linux"), feature = "net"))] + pub unsafe extern "C" fn krun_add_net_tap( + _ctx_id: u32, + _c_tap_name: *const c_char, + _c_mac: *const u8, + _features: u32, + _flags: u32, + ) -> i32 { + -libc::EINVAL + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "net")] + pub unsafe extern "C" fn krun_set_passt_fd(ctx_id: u32, fd: c_int) -> i32 { + if fd < 0 { + return -libc::EINVAL; + } + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + // The legacy interface only supports a single network interface. 
+ if cfg.net_index != 0 { return -libc::EINVAL; } + cfg.legacy_net_cfg = Some(LegacyNetworkConfig::VirtioNetPasst(fd)); } - port_map.insert(guest_port, host_port); + Entry::Vacant(_) => return -libc::ENOENT, } + KRUN_SUCCESS } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - if cfg.vsock_config == VsockConfig::Disabled { - return -libc::ENODEV; - } - if cfg.set_port_map(port_map).is_err() { + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "net")] + pub unsafe extern "C" fn krun_set_gvproxy_path(ctx_id: u32, c_path: *const c_char) -> i32 { + let path_str = match CStr::from_ptr(c_path).to_str() { + Ok(path) => path, + Err(e) => { + debug!("Error parsing gvproxy_path: {e:?}"); return -libc::EINVAL; } + }; + + let path = PathBuf::from(path_str); + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + // The legacy interface only supports a single network interface. 
+ if cfg.net_index != 0 { + return -libc::EINVAL; + } + cfg.legacy_net_cfg = Some(LegacyNetworkConfig::VirtioNetGvproxy(path)); + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + KRUN_SUCCESS } - KRUN_SUCCESS -} + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "net")] + pub unsafe extern "C" fn krun_set_net_mac(ctx_id: u32, c_mac: *const u8) -> i32 { + let mac: [u8; 6] = match slice::from_raw_parts(c_mac, 6).try_into() { + Ok(m) => m, + Err(_) => return -libc::EINVAL, + }; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_rlimits(ctx_id: u32, c_rlimits: *const *const c_char) -> i32 { - let rlimits = if c_rlimits.is_null() { - return -libc::EINVAL; - } else { - let mut strvec = Vec::new(); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_net_mac(mac); + } + Entry::Vacant(_) => return -libc::ENOENT, + } + KRUN_SUCCESS + } - let array: &[*const c_char] = slice::from_raw_parts(c_rlimits, MAX_ARGS); - for item in array.iter().take(MAX_ARGS) { + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_port_map( + ctx_id: u32, + c_port_map: *const *const c_char, + ) -> i32 { + let mut port_map = HashMap::new(); + let port_map_array: &[*const c_char] = slice::from_raw_parts(c_port_map, MAX_ARGS); + for item in port_map_array.iter().take(MAX_ARGS) { if item.is_null() { break; } else { @@ -1268,1521 +1251,1641 @@ pub unsafe extern "C" fn krun_set_rlimits(ctx_id: u32, c_rlimits: *const *const Ok(s) => s, Err(_) => return -libc::EINVAL, }; - strvec.push(s); + let port_tuple: Vec<&str> = s.split(':').collect(); + if port_tuple.len() != 2 { + return -libc::EINVAL; + } + let host_port: u16 = match port_tuple[0].parse() { + Ok(p) => p, + Err(_) => return -libc::EINVAL, + }; + let guest_port: u16 = match port_tuple[1].parse() { + Ok(p) => p, + Err(_) => return 
-libc::EINVAL, + }; + + if port_map.contains_key(&guest_port) { + return -libc::EINVAL; + } + for hp in port_map.values() { + if *hp == host_port { + return -libc::EINVAL; + } + } + port_map.insert(guest_port, host_port); } } - format!("\"{}\"", strvec.join(",")) - }; - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - ctx_cfg.get_mut().set_rlimits(rlimits); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if cfg.vsock_config == VsockConfig::Disabled { + return -libc::ENODEV; + } + if cfg.set_port_map(port_map).is_err() { + return -libc::EINVAL; + } + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + + KRUN_SUCCESS } - KRUN_SUCCESS -} + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_rlimits(ctx_id: u32, c_rlimits: *const *const c_char) -> i32 { + let rlimits = if c_rlimits.is_null() { + return -libc::EINVAL; + } else { + let mut strvec = Vec::new(); -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_workdir(ctx_id: u32, c_workdir_path: *const c_char) -> i32 { - let workdir_path = match CStr::from_ptr(c_workdir_path).to_str() { - Ok(workdir) => workdir, - Err(_) => return -libc::EINVAL, - }; + let array: &[*const c_char] = slice::from_raw_parts(c_rlimits, MAX_ARGS); + for item in array.iter().take(MAX_ARGS) { + if item.is_null() { + break; + } else { + let s = match CStr::from_ptr(*item).to_str() { + Ok(s) => s, + Err(_) => return -libc::EINVAL, + }; + strvec.push(s); + } + } + + format!("\"{}\"", strvec.join(",")) + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - ctx_cfg.get_mut().set_workdir(workdir_path.to_string()); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + ctx_cfg.get_mut().set_rlimits(rlimits); + } + Entry::Vacant(_) => return -libc::ENOENT, } - 
Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} + KRUN_SUCCESS + } -unsafe fn collapse_str_array(array: &[*const c_char]) -> Result { - let mut strvec = Vec::new(); + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_workdir(ctx_id: u32, c_workdir_path: *const c_char) -> i32 { + let workdir_path = match CStr::from_ptr(c_workdir_path).to_str() { + Ok(workdir) => workdir, + Err(_) => return -libc::EINVAL, + }; - for item in array.iter().take(MAX_ARGS) { - if item.is_null() { - break; - } else { - let s = CStr::from_ptr(*item).to_str()?; - strvec.push(format!("\"{s}\"")); - } - } - - Ok(strvec.join(" ")) -} - -#[allow(clippy::format_collect)] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_exec( - ctx_id: u32, - c_exec_path: *const c_char, - c_argv: *const *const c_char, - c_envp: *const *const c_char, -) -> i32 { - let exec_path = match CStr::from_ptr(c_exec_path).to_str() { - Ok(path) => path, - Err(e) => { - debug!("Error parsing exec_path: {e:?}"); - return -libc::EINVAL; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + ctx_cfg.get_mut().set_workdir(workdir_path.to_string()); + } + Entry::Vacant(_) => return -libc::ENOENT, } - }; - let args = if !c_argv.is_null() { - let argv_array: &[*const c_char] = slice::from_raw_parts(c_argv, MAX_ARGS); - match collapse_str_array(argv_array) { - Ok(s) => s, - Err(e) => { - debug!("Error parsing args: {e:?}"); - return -libc::EINVAL; + KRUN_SUCCESS + } + + unsafe fn collapse_str_array(array: &[*const c_char]) -> Result { + let mut strvec = Vec::new(); + + for item in array.iter().take(MAX_ARGS) { + if item.is_null() { + break; + } else { + let s = CStr::from_ptr(*item).to_str()?; + strvec.push(format!("\"{s}\"")); } } - } else { - "".to_string() - }; - let env = if !c_envp.is_null() { - let envp_array: &[*const c_char] = slice::from_raw_parts(c_envp, MAX_ARGS); - match collapse_str_array(envp_array) 
{ - Ok(s) => s, + Ok(strvec.join(" ")) + } + + #[allow(clippy::format_collect)] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_exec( + ctx_id: u32, + c_exec_path: *const c_char, + c_argv: *const *const c_char, + c_envp: *const *const c_char, + ) -> i32 { + let exec_path = match CStr::from_ptr(c_exec_path).to_str() { + Ok(path) => path, Err(e) => { - debug!("Error parsing args: {e:?}"); + debug!("Error parsing exec_path: {e:?}"); return -libc::EINVAL; } - } - } else { - env::vars() - .map(|(key, value)| format!(" {key}=\"{value}\"")) - .collect() - }; + }; + + let args = if !c_argv.is_null() { + let argv_array: &[*const c_char] = slice::from_raw_parts(c_argv, MAX_ARGS); + match collapse_str_array(argv_array) { + Ok(s) => s, + Err(e) => { + debug!("Error parsing args: {e:?}"); + return -libc::EINVAL; + } + } + } else { + "".to_string() + }; + + let env = if !c_envp.is_null() { + let envp_array: &[*const c_char] = slice::from_raw_parts(c_envp, MAX_ARGS); + match collapse_str_array(envp_array) { + Ok(s) => s, + Err(e) => { + debug!("Error parsing args: {e:?}"); + return -libc::EINVAL; + } + } + } else { + env::vars() + .map(|(key, value)| format!(" {key}=\"{value}\"")) + .collect() + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_exec_path(exec_path.to_string()); - cfg.set_env(env); - cfg.set_args(args); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_exec_path(exec_path.to_string()); + cfg.set_env(env); + cfg.set_args(args); + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + + KRUN_SUCCESS } - KRUN_SUCCESS -} + #[allow(clippy::format_collect)] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_env(ctx_id: u32, c_envp: *const *const c_char) -> i32 { + let env = if !c_envp.is_null() { + let 
envp_array: &[*const c_char] = slice::from_raw_parts(c_envp, MAX_ARGS); + match collapse_str_array(envp_array) { + Ok(s) => s, + Err(e) => { + debug!("Error parsing args: {e:?}"); + return -libc::EINVAL; + } + } + } else { + env::vars() + .map(|(key, value)| format!(" {key}=\"{value}\"")) + .collect() + }; -#[allow(clippy::format_collect)] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_env(ctx_id: u32, c_envp: *const *const c_char) -> i32 { - let env = if !c_envp.is_null() { - let envp_array: &[*const c_char] = slice::from_raw_parts(c_envp, MAX_ARGS); - match collapse_str_array(envp_array) { - Ok(s) => s, - Err(e) => { - debug!("Error parsing args: {e:?}"); - return -libc::EINVAL; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_env(env); } + Entry::Vacant(_) => return -libc::ENOENT, } - } else { - env::vars() - .map(|(key, value)| format!(" {key}=\"{value}\"")) - .collect() - }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_env(env); - } - Entry::Vacant(_) => return -libc::ENOENT, + KRUN_SUCCESS } - KRUN_SUCCESS -} + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + #[cfg(feature = "tee")] + pub unsafe extern "C" fn krun_set_tee_config_file( + ctx_id: u32, + c_filepath: *const c_char, + ) -> i32 { + let filepath = match CStr::from_ptr(c_filepath).to_str() { + Ok(f) => f, + Err(_) => return -libc::EINVAL, + }; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -#[cfg(feature = "tee")] -pub unsafe extern "C" fn krun_set_tee_config_file(ctx_id: u32, c_filepath: *const c_char) -> i32 { - let filepath = match CStr::from_ptr(c_filepath).to_str() { - Ok(f) => f, - Err(_) => return -libc::EINVAL, - }; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + 
cfg.set_tee_config_file(PathBuf::from(filepath.to_string())); + } + Entry::Vacant(_) => return -libc::ENOENT, + } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_tee_config_file(PathBuf::from(filepath.to_string())); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_vsock_port( - ctx_id: u32, - port: u32, - c_filepath: *const c_char, -) -> i32 { - krun_add_vsock_port2(ctx_id, port, c_filepath, false) -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_vsock_port2( - ctx_id: u32, - port: u32, - c_filepath: *const c_char, - listen: bool, -) -> i32 { - #[cfg(feature = "aws-nitro")] - if listen { - return -libc::EINVAL; + KRUN_SUCCESS } - let filepath = match CStr::from_ptr(c_filepath).to_str() { - Ok(f) => PathBuf::from(f.to_string()), - Err(_) => return -libc::EINVAL, - }; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_vsock_port( + ctx_id: u32, + port: u32, + c_filepath: *const c_char, + ) -> i32 { + krun_add_vsock_port2(ctx_id, port, c_filepath, false) + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_vsock_port2( + ctx_id: u32, + port: u32, + c_filepath: *const c_char, + listen: bool, + ) -> i32 { + #[cfg(feature = "aws-nitro")] + if listen { + return -libc::EINVAL; + } - if listen { - match filepath.try_exists() { - Ok(true) => return -libc::EEXIST, + let filepath = match CStr::from_ptr(c_filepath).to_str() { + Ok(f) => PathBuf::from(f.to_string()), Err(_) => return -libc::EINVAL, - _ => {} - } - } + }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - if cfg.vsock_config == VsockConfig::Disabled { - return -libc::ENODEV; + if listen { + match filepath.try_exists() { + Ok(true) => return 
-libc::EEXIST, + Err(_) => return -libc::EINVAL, + _ => {} } - cfg.add_vsock_port(port, filepath, listen); - } - Entry::Vacant(_) => return -libc::ENOENT, - } + } - KRUN_SUCCESS -} + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if cfg.vsock_config == VsockConfig::Disabled { + return -libc::ENODEV; + } + cfg.add_vsock_port(port, filepath, listen); + } + Entry::Vacant(_) => return -libc::ENOENT, + } -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_gpu_options(ctx_id: u32, virgl_flags: u32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_gpu_virgl_flags(virgl_flags); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_gpu_options2( - ctx_id: u32, - virgl_flags: u32, - shm_size: u64, -) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_gpu_virgl_flags(virgl_flags); - cfg.set_gpu_shm_size(shm_size.try_into().unwrap()); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[cfg(not(feature = "gpu"))] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub extern "C" fn krun_set_display_backend( - _ctx_id: u32, - _features: u32, - _vtable: *const c_void, - _vtable_size: usize, -) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "gpu")] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub extern "C" fn krun_set_display_backend( - ctx_id: u32, - vtable: *const c_void, - vtable_size: usize, -) -> i32 { - if vtable_size < size_of::() { - return -libc::EINVAL; - } - - // SAFETY: We have checked the vtable size is fine, otherwise we have to trust the user. Just - // to be extra careful, this uses read_unaligned, but we could probably get away with ptr::read. 
- let display_backend: DisplayBackend = - unsafe { std::ptr::read_unaligned(vtable as *const DisplayBackend) }; - - if !display_backend.verify() { - return -libc::EINVAL; - } - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.vmr.display_backend = Some(display_backend); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[cfg(not(feature = "input"))] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub extern "C" fn krun_add_input_device( - _ctx_id: u32, - _config_backend: *const c_void, - _config_backend_size: size_t, - _event_provider_backend: *const c_void, - _event_provider_backend_size: size_t, -) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "input")] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub extern "C" fn krun_add_input_device_fd(ctx_id: u32, input_fd: i32) -> i32 { - use devices::virtio::input::passthrough::PassthroughInputBackend; - use krun_input::{IntoInputConfig, IntoInputEvents}; - - if input_fd < 0 { - return -libc::EINVAL; - } - // TODO: currently we let the fd (and it's Box allocation) live forever, we should eventually fix - // this - let input_fd = unsafe { - // SAFETY: The user provided fd should be valid. 
Its lifetime is 'static because it will - // exist until libkrun _exits the process - BorrowedFd::borrow_raw(input_fd) - }; - let borrowed_fd: &'static BorrowedFd<'static> = Box::leak(Box::new(input_fd)); + KRUN_SUCCESS + } - let config_backend = PassthroughInputBackend::into_input_config(Some(borrowed_fd)); - let events_backend = PassthroughInputBackend::into_input_events(Some(borrowed_fd)); + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_gpu_options(ctx_id: u32, virgl_flags: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_gpu_virgl_flags(virgl_flags); + } + Entry::Vacant(_) => return -libc::ENOENT, + } - with_cfg(ctx_id, |cfg| { - cfg.vmr - .input_backends - .push((config_backend, events_backend)); KRUN_SUCCESS - }) -} - -#[cfg(feature = "input")] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_input_device( - ctx_id: u32, - config_backend: *const InputConfigBackend<'static>, - config_backend_size: size_t, - event_provider_backend: *const InputEventProviderBackend<'static>, - event_provider_backend_size: size_t, -) -> i32 { - if config_backend.is_null() || event_provider_backend.is_null() { - return -libc::EINVAL; } - if config_backend_size < size_of::() - || event_provider_backend_size < size_of::() - { - return -libc::EINVAL; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_gpu_options2( + ctx_id: u32, + virgl_flags: u32, + shm_size: u64, + ) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_gpu_virgl_flags(virgl_flags); + cfg.set_gpu_shm_size(shm_size.try_into().unwrap()); + } + Entry::Vacant(_) => return -libc::ENOENT, + } + + KRUN_SUCCESS } - let config_backend = unsafe { *config_backend }; - let events_backend = unsafe { *event_provider_backend }; + 
#[cfg(not(feature = "gpu"))] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub extern "C" fn krun_set_display_backend( + _ctx_id: u32, + _features: u32, + _vtable: *const c_void, + _vtable_size: usize, + ) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "gpu")] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub extern "C" fn krun_set_display_backend( + ctx_id: u32, + vtable: *const c_void, + vtable_size: usize, + ) -> i32 { + if vtable_size < size_of::() { + return -libc::EINVAL; + } - if !config_backend.verify() || !events_backend.verify() { - return -libc::EINVAL; - } + // SAFETY: We have checked the vtable size is fine, otherwise we have to trust the user. Just + // to be extra careful, this uses read_unaligned, but we could probably get away with ptr::read. + let display_backend: DisplayBackend = + unsafe { std::ptr::read_unaligned(vtable as *const DisplayBackend) }; - with_cfg(ctx_id, |cfg| { - cfg.vmr - .input_backends - .push((config_backend, events_backend)); - KRUN_SUCCESS - }) -} - -#[cfg(not(feature = "input"))] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_input_device_fd(_ctx_id: u32, _input_fd: i32) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "gpu")] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_display(ctx_id: u32, width: u32, height: u32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - if cfg.vmr.displays.len() >= MAX_DISPLAYS { - return -libc::ENOMEM; - } - - cfg.vmr.displays.push(DisplayInfo::new(width, height)); - (cfg.vmr.displays.len() - 1) as i32 - } - Entry::Vacant(_) => -libc::ENOENT, - } -} - -#[cfg(not(feature = "gpu"))] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_display(_ctx_id: u32, _width: u32, _height: u32) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "gpu")] -#[no_mangle] -pub extern "C" fn 
krun_display_set_refresh_rate( - ctx_id: u32, - display_id: u32, - refresh_rate: u32, -) -> i32 { - with_cfg(ctx_id, |cfg| { - let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + if !display_backend.verify() { return -libc::EINVAL; - }; + } - let DisplayInfoEdid::Generated(ref mut edid_params) = display_info.edid else { - return -libc::EALREADY; - }; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.display_backend = Some(display_backend); + } + Entry::Vacant(_) => return -libc::ENOENT, + } - edid_params.refresh_rate = refresh_rate; KRUN_SUCCESS - }) -} - -#[cfg(not(feature = "gpu"))] -#[no_mangle] -pub extern "C" fn krun_display_set_refresh_rate( - _ctx_id: u32, - _display_id: u32, - _refresh_rate: u32, -) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "gpu")] -#[no_mangle] -#[allow(clippy::missing_safety_doc)] -pub unsafe extern "C" fn krun_display_set_edid( - ctx_id: u32, - display_id: u32, - edid: *const u8, - size: size_t, -) -> i32 { - with_cfg(ctx_id, |cfg| { - let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + } + + #[cfg(not(feature = "input"))] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub extern "C" fn krun_add_input_device( + _ctx_id: u32, + _config_backend: *const c_void, + _config_backend_size: size_t, + _event_provider_backend: *const c_void, + _event_provider_backend_size: size_t, + ) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "input")] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub extern "C" fn krun_add_input_device_fd(ctx_id: u32, input_fd: i32) -> i32 { + use devices::virtio::input::passthrough::PassthroughInputBackend; + use krun_input::{IntoInputConfig, IntoInputEvents}; + + if input_fd < 0 { return -libc::EINVAL; + } + // TODO: currently we let the fd (and it's Box allocation) live forever, we should eventually fix + // this + let input_fd = unsafe { + // SAFETY: The user 
provided fd should be valid. Its lifetime is 'static because it will + // exist until libkrun _exits the process + BorrowedFd::borrow_raw(input_fd) }; + let borrowed_fd: &'static BorrowedFd<'static> = Box::leak(Box::new(input_fd)); + + let config_backend = PassthroughInputBackend::into_input_config(Some(borrowed_fd)); + let events_backend = PassthroughInputBackend::into_input_events(Some(borrowed_fd)); + + with_cfg(ctx_id, |cfg| { + cfg.vmr + .input_backends + .push((config_backend, events_backend)); + KRUN_SUCCESS + }) + } - if edid.is_null() { + #[cfg(feature = "input")] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_input_device( + ctx_id: u32, + config_backend: *const InputConfigBackend<'static>, + config_backend_size: size_t, + event_provider_backend: *const InputEventProviderBackend<'static>, + event_provider_backend_size: size_t, + ) -> i32 { + if config_backend.is_null() || event_provider_backend.is_null() { return -libc::EINVAL; } - let blob = unsafe { slice::from_raw_parts(edid, size) }; + if config_backend_size < size_of::() + || event_provider_backend_size < size_of::() + { + return -libc::EINVAL; + } - display_info.edid = DisplayInfoEdid::Provided(Box::from(blob)); - KRUN_SUCCESS - }) -} - -#[cfg(not(feature = "gpu"))] -#[no_mangle] -#[allow(clippy::missing_safety_doc)] -pub unsafe extern "C" fn krun_display_set_edid( - _ctx_id: u32, - _display_id: u32, - _edid: *const u8, - _size: size_t, -) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "gpu")] -#[no_mangle] -pub extern "C" fn krun_display_set_physical_size( - ctx_id: u32, - display_id: u32, - width_mm: u16, - height_mm: u16, -) -> i32 { - with_cfg(ctx_id, |cfg| { - let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + let config_backend = unsafe { *config_backend }; + let events_backend = unsafe { *event_provider_backend }; + + if !config_backend.verify() || !events_backend.verify() { return -libc::EINVAL; - }; - let 
DisplayInfoEdid::Generated(ref mut edid_params) = display_info.edid else { - return -libc::EALREADY; - }; - edid_params.physical_size = PhysicalSize::DimensionsMillimeters(width_mm, height_mm); + } + + with_cfg(ctx_id, |cfg| { + cfg.vmr + .input_backends + .push((config_backend, events_backend)); + KRUN_SUCCESS + }) + } + + #[cfg(not(feature = "input"))] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_input_device_fd(_ctx_id: u32, _input_fd: i32) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "gpu")] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_display(ctx_id: u32, width: u32, height: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if cfg.vmr.displays.len() >= MAX_DISPLAYS { + return -libc::ENOMEM; + } + + cfg.vmr.displays.push(DisplayInfo::new(width, height)); + (cfg.vmr.displays.len() - 1) as i32 + } + Entry::Vacant(_) => -libc::ENOENT, + } + } + + #[cfg(not(feature = "gpu"))] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_display(_ctx_id: u32, _width: u32, _height: u32) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "gpu")] + #[no_mangle] + pub extern "C" fn krun_display_set_refresh_rate( + ctx_id: u32, + display_id: u32, + refresh_rate: u32, + ) -> i32 { + with_cfg(ctx_id, |cfg| { + let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + return -libc::EINVAL; + }; + + let DisplayInfoEdid::Generated(ref mut edid_params) = display_info.edid else { + return -libc::EALREADY; + }; + + edid_params.refresh_rate = refresh_rate; + KRUN_SUCCESS + }) + } + + #[cfg(not(feature = "gpu"))] + #[no_mangle] + pub extern "C" fn krun_display_set_refresh_rate( + _ctx_id: u32, + _display_id: u32, + _refresh_rate: u32, + ) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "gpu")] + #[no_mangle] + #[allow(clippy::missing_safety_doc)] + pub 
unsafe extern "C" fn krun_display_set_edid( + ctx_id: u32, + display_id: u32, + edid: *const u8, + size: size_t, + ) -> i32 { + with_cfg(ctx_id, |cfg| { + let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + return -libc::EINVAL; + }; + + if edid.is_null() { + return -libc::EINVAL; + } + + let blob = unsafe { slice::from_raw_parts(edid, size) }; + + display_info.edid = DisplayInfoEdid::Provided(Box::from(blob)); + KRUN_SUCCESS + }) + } + + #[cfg(not(feature = "gpu"))] + #[no_mangle] + #[allow(clippy::missing_safety_doc)] + pub unsafe extern "C" fn krun_display_set_edid( + _ctx_id: u32, + _display_id: u32, + _edid: *const u8, + _size: size_t, + ) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "gpu")] + #[no_mangle] + pub extern "C" fn krun_display_set_physical_size( + ctx_id: u32, + display_id: u32, + width_mm: u16, + height_mm: u16, + ) -> i32 { + with_cfg(ctx_id, |cfg| { + let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + return -libc::EINVAL; + }; + let DisplayInfoEdid::Generated(ref mut edid_params) = display_info.edid else { + return -libc::EALREADY; + }; + edid_params.physical_size = PhysicalSize::DimensionsMillimeters(width_mm, height_mm); + KRUN_SUCCESS + }) + } + + #[cfg(not(feature = "gpu"))] + #[no_mangle] + pub extern "C" fn krun_display_set_physical_size( + _ctx_id: u32, + _display_id: u32, + _width_mm: u16, + _height_mm: u16, + ) -> i32 { + -libc::ENOTSUP + } + + #[cfg(feature = "gpu")] + #[no_mangle] + #[allow(clippy::missing_safety_doc)] + pub extern "C" fn krun_display_set_dpi(ctx_id: u32, display_id: u32, dpi: u32) -> i32 { + with_cfg(ctx_id, |cfg| { + let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { + return -libc::EINVAL; + }; + let DisplayInfoEdid::Generated(ref mut edid_params) = display_info.edid else { + return -libc::EINVAL; + }; + edid_params.physical_size = PhysicalSize::Dpi(dpi); + KRUN_SUCCESS + }) + } + + #[cfg(not(feature = "gpu"))] + #[no_mangle] 
+ pub extern "C" fn krun_display_set_dpi(_ctx_id: u32, _display_id: u32, _dpi: u32) -> i32 { + -libc::ENOTSUP + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_snd_device(ctx_id: u32, enable: bool) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.enable_snd = enable; + } + Entry::Vacant(_) => return -libc::ENOENT, + } + KRUN_SUCCESS - }) -} - -#[cfg(not(feature = "gpu"))] -#[no_mangle] -pub extern "C" fn krun_display_set_physical_size( - _ctx_id: u32, - _display_id: u32, - _width_mm: u16, - _height_mm: u16, -) -> i32 { - -libc::ENOTSUP -} - -#[cfg(feature = "gpu")] -#[no_mangle] -#[allow(clippy::missing_safety_doc)] -pub extern "C" fn krun_display_set_dpi(ctx_id: u32, display_id: u32, dpi: u32) -> i32 { - with_cfg(ctx_id, |cfg| { - let Some(display_info) = cfg.vmr.displays.get_mut(display_id as usize) else { - return -libc::EINVAL; + } + + #[allow(unused_assignments)] + #[no_mangle] + pub extern "C" fn krun_get_shutdown_eventfd(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if let Some(efd) = cfg.shutdown_efd.as_ref() { + #[cfg(target_os = "macos")] + return efd.get_write_fd(); + #[cfg(target_os = "linux")] + return efd.as_raw_fd(); + } else { + -libc::EINVAL + } + } + Entry::Vacant(_) => -libc::ENOENT, + } + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_console_output( + ctx_id: u32, + c_filepath: *const c_char, + ) -> i32 { + let filepath = match CStr::from_ptr(c_filepath).to_str() { + Ok(f) => f, + Err(_) => return -libc::EINVAL, }; - let DisplayInfoEdid::Generated(ref mut edid_params) = display_info.edid else { - return -libc::EINVAL; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if cfg.console_output.is_some() { + 
-libc::EINVAL + } else { + cfg.console_output = Some(PathBuf::from(filepath.to_string())); + KRUN_SUCCESS + } + } + Entry::Vacant(_) => -libc::ENOENT, + } + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_nested_virt(ctx_id: u32, enabled: bool) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.nested_enabled = enabled; + KRUN_SUCCESS + } + Entry::Vacant(_) => -libc::ENOENT, + } + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_check_nested_virt() -> i32 { + #[cfg(target_os = "macos")] + match hvf::check_nested_virt() { + Ok(supp) => supp as i32, + Err(_) => -libc::EINVAL, + } + + #[cfg(target_os = "linux")] + { + let paths = [ + "/sys/module/kvm_intel/parameters/nested", + "/sys/module/kvm_amd/parameters/nested", + ]; + if paths.iter().any(|path| { + std::fs::read_to_string(path).is_ok_and(|contents| { + let val = contents.trim(); + val == "1" || val.eq_ignore_ascii_case("Y") + }) + }) { + 1 + } else { + 0 + } + } + + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + -libc::EOPNOTSUPP + } + + const KRUN_FEATURE_NET: u64 = 0; + const KRUN_FEATURE_BLK: u64 = 1; + const KRUN_FEATURE_GPU: u64 = 2; + const KRUN_FEATURE_SND: u64 = 3; + const KRUN_FEATURE_INPUT: u64 = 4; + const KRUN_FEATURE_TEE: u64 = 6; + const KRUN_FEATURE_AMD_SEV: u64 = 7; + const KRUN_FEATURE_INTEL_TDX: u64 = 8; + const KRUN_FEATURE_AWS_NITRO: u64 = 9; + const KRUN_FEATURE_VIRGL_RESOURCE_MAP2: u64 = 10; + + #[no_mangle] + pub extern "C" fn krun_has_feature(feature: u64) -> c_int { + let supported = match feature { + KRUN_FEATURE_NET => cfg!(feature = "net"), + KRUN_FEATURE_BLK => cfg!(feature = "blk"), + KRUN_FEATURE_GPU => cfg!(feature = "gpu"), + KRUN_FEATURE_SND => cfg!(feature = "snd"), + KRUN_FEATURE_INPUT => cfg!(feature = "input"), + KRUN_FEATURE_TEE => cfg!(feature = "tee"), + KRUN_FEATURE_AMD_SEV => cfg!(feature 
= "amd-sev"), + KRUN_FEATURE_INTEL_TDX => cfg!(feature = "tdx"), + KRUN_FEATURE_AWS_NITRO => cfg!(feature = "aws-nitro"), + KRUN_FEATURE_VIRGL_RESOURCE_MAP2 => cfg!(feature = "virgl_resource_map2"), + _ => return -libc::EINVAL, }; - edid_params.physical_size = PhysicalSize::Dpi(dpi); - KRUN_SUCCESS - }) -} - -#[cfg(not(feature = "gpu"))] -#[no_mangle] -pub extern "C" fn krun_display_set_dpi(_ctx_id: u32, _display_id: u32, _dpi: u32) -> i32 { - -libc::ENOTSUP -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_snd_device(ctx_id: u32, enable: bool) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.enable_snd = enable; - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[allow(unused_assignments)] -#[no_mangle] -pub extern "C" fn krun_get_shutdown_eventfd(ctx_id: u32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - if let Some(efd) = cfg.shutdown_efd.as_ref() { - #[cfg(target_os = "macos")] - return efd.get_write_fd(); - #[cfg(target_os = "linux")] - return efd.as_raw_fd(); + + supported as c_int + } + + /// Gets the maximum number of vCPUs supported by the hypervisor. + /// + /// Returns the maximum number of vCPUs that can be created by this hypervisor, + /// or a negative error code on failure. 
+ #[cfg(any(target_os = "macos", target_os = "linux"))] + #[no_mangle] + pub extern "C" fn krun_get_max_vcpus() -> i32 { + #[cfg(target_os = "macos")] + { + use hvf::bindings::{hv_vm_get_max_vcpu_count, HV_SUCCESS}; + let mut max_vcpu_count: u32 = 0; + let ret = unsafe { hv_vm_get_max_vcpu_count(&mut max_vcpu_count as *mut u32) }; + if ret == HV_SUCCESS { + max_vcpu_count as i32 } else { + error!("Error retrieving max vcpu count: {ret:?}"); -libc::EINVAL } } - Entry::Vacant(_) => -libc::ENOENT, - } -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_console_output(ctx_id: u32, c_filepath: *const c_char) -> i32 { - let filepath = match CStr::from_ptr(c_filepath).to_str() { - Ok(f) => f, - Err(_) => return -libc::EINVAL, - }; + #[cfg(target_os = "linux")] + { + use kvm_ioctls::Kvm; + match Kvm::new() { + Ok(kvm) => kvm.get_max_vcpus() as i32, + Err(e) => { + error!("Error retrieving max vcpu count: {e:?}"); + -libc::EINVAL + } + } + } + } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - if cfg.console_output.is_some() { - -libc::EINVAL - } else { - cfg.console_output = Some(PathBuf::from(filepath.to_string())); + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub extern "C" fn krun_split_irqchip(ctx_id: u32, enable: bool) -> i32 { + if enable && !cfg!(target_arch = "x86_64") { + return -libc::EINVAL; + } + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.split_irqchip = enable; KRUN_SUCCESS } + Entry::Vacant(_) => -libc::ENOENT, } - Entry::Vacant(_) => -libc::ENOENT, } -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_nested_virt(ctx_id: u32, enabled: bool) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.vmr.nested_enabled = enabled; - KRUN_SUCCESS + 
#[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_smbios_oem_strings( + ctx_id: u32, + oem_strings: *const *const c_char, + ) -> i32 { + if oem_strings.is_null() { + return -libc::EINVAL; } - Entry::Vacant(_) => -libc::ENOENT, + + let cstr_ptr_slice = slice::from_raw_parts(oem_strings, MAX_ARGS); + + let mut oem_strings = Vec::new(); + + for cstr_ptr in cstr_ptr_slice.iter().take_while(|p| !p.is_null()) { + let Ok(s) = CStr::from_ptr(*cstr_ptr).to_str() else { + return -libc::EINVAL; + }; + oem_strings.push(s.to_string()); + } + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + ctx_cfg.get_mut().vmr.smbios_oem_strings = + (!oem_strings.is_empty()).then_some(oem_strings) + } + Entry::Vacant(_) => return -libc::ENOENT, + } + + KRUN_SUCCESS } -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_check_nested_virt() -> i32 { - #[cfg(target_os = "macos")] - match hvf::check_nested_virt() { - Ok(supp) => supp as i32, - Err(_) => -libc::EINVAL, + #[cfg(feature = "net")] + fn create_virtio_net( + ctx_cfg: &mut ContextConfig, + backend: VirtioNetBackend, + mac: [u8; 6], + features: u32, + ) { + let network_interface_config = NetworkInterfaceConfig { + iface_id: format!("eth{}", ctx_cfg.net_index), + backend, + mac, + features, + }; + ctx_cfg.net_index += 1; + ctx_cfg + .vmr + .add_network_interface(network_interface_config) + .expect("Failed to create network interface"); } - #[cfg(target_os = "linux")] - { - let paths = [ - "/sys/module/kvm_intel/parameters/nested", - "/sys/module/kvm_amd/parameters/nested", - ]; - if paths.iter().any(|path| { - std::fs::read_to_string(path).is_ok_and(|contents| { - let val = contents.trim(); - val == "1" || val.eq_ignore_ascii_case("Y") - }) - }) { - 1 - } else { - 0 - } - } - - #[cfg(not(any(target_os = "macos", target_os = "linux")))] - -libc::EOPNOTSUPP -} - -const KRUN_FEATURE_NET: u64 = 0; -const KRUN_FEATURE_BLK: u64 = 1; -const 
KRUN_FEATURE_GPU: u64 = 2; -const KRUN_FEATURE_SND: u64 = 3; -const KRUN_FEATURE_INPUT: u64 = 4; -const KRUN_FEATURE_TEE: u64 = 6; -const KRUN_FEATURE_AMD_SEV: u64 = 7; -const KRUN_FEATURE_INTEL_TDX: u64 = 8; -const KRUN_FEATURE_AWS_NITRO: u64 = 9; -const KRUN_FEATURE_VIRGL_RESOURCE_MAP2: u64 = 10; - -#[no_mangle] -pub extern "C" fn krun_has_feature(feature: u64) -> c_int { - let supported = match feature { - KRUN_FEATURE_NET => cfg!(feature = "net"), - KRUN_FEATURE_BLK => cfg!(feature = "blk"), - KRUN_FEATURE_GPU => cfg!(feature = "gpu"), - KRUN_FEATURE_SND => cfg!(feature = "snd"), - KRUN_FEATURE_INPUT => cfg!(feature = "input"), - KRUN_FEATURE_TEE => cfg!(feature = "tee"), - KRUN_FEATURE_AMD_SEV => cfg!(feature = "amd-sev"), - KRUN_FEATURE_INTEL_TDX => cfg!(feature = "tdx"), - KRUN_FEATURE_AWS_NITRO => cfg!(feature = "aws-nitro"), - KRUN_FEATURE_VIRGL_RESOURCE_MAP2 => cfg!(feature = "virgl_resource_map2"), - _ => return -libc::EINVAL, - }; + #[cfg(all(target_arch = "x86_64", not(feature = "tee")))] + fn map_kernel(ctx_id: u32, kernel_path: &PathBuf) -> i32 { + let file = match File::options().read(true).write(false).open(kernel_path) { + Ok(file) => file, + Err(err) => { + error!("Error opening external kernel: {err}"); + return -libc::EINVAL; + } + }; - supported as c_int -} + let kernel_size = file.metadata().unwrap().len(); + + let kernel_host_addr = unsafe { + libc::mmap( + std::ptr::null_mut(), + kernel_size as usize, + libc::PROT_READ, + libc::MAP_SHARED, + file.as_raw_fd(), + 0_i64, + ) + }; + if std::ptr::eq(kernel_host_addr, libc::MAP_FAILED) { + error!("Can't load kernel into process map"); + return -libc::EINVAL; + } -/// Gets the maximum number of vCPUs supported by the hypervisor. -/// -/// Returns the maximum number of vCPUs that can be created by this hypervisor, -/// or a negative error code on failure. 
-#[cfg(any(target_os = "macos", target_os = "linux"))] -#[no_mangle] -pub extern "C" fn krun_get_max_vcpus() -> i32 { - #[cfg(target_os = "macos")] - { - use hvf::bindings::{hv_vm_get_max_vcpu_count, HV_SUCCESS}; - let mut max_vcpu_count: u32 = 0; - let ret = unsafe { hv_vm_get_max_vcpu_count(&mut max_vcpu_count as *mut u32) }; - if ret == HV_SUCCESS { - max_vcpu_count as i32 + let kernel_bundle = KernelBundle { + host_addr: kernel_host_addr as u64, + guest_addr: 0x8000_0000, + entry_addr: 0x8000_0000, + size: kernel_size as usize, + }; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => ctx_cfg + .get_mut() + .vmr + .set_kernel_bundle(kernel_bundle) + .unwrap(), + Entry::Vacant(_) => return -libc::ENOENT, + } + + KRUN_SUCCESS + } + + #[cfg(feature = "tee")] + #[allow(clippy::format_collect)] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_kernel(_ctx_id: u32, _c_kernel_path: *const c_char) -> i32 { + -libc::EOPNOTSUPP + } + + #[cfg(not(feature = "tee"))] + #[allow(clippy::format_collect)] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_kernel( + ctx_id: u32, + c_kernel_path: *const c_char, + kernel_format: u32, + c_initramfs_path: *const c_char, + c_cmdline: *const c_char, + ) -> i32 { + let path = match CStr::from_ptr(c_kernel_path).to_str() { + Ok(path) => PathBuf::from(path), + Err(e) => { + error!("Error parsing kernel_path: {e:?}"); + return -libc::EINVAL; + } + }; + + let format = match kernel_format { + // For raw kernels in x86_64, we map the kernel into the + // process and treat it as a bundled kernel. 
+ #[cfg(all(target_arch = "x86_64", not(feature = "tee")))] + 0 => return map_kernel(ctx_id, &path), + #[cfg(target_arch = "aarch64")] + 0 => KernelFormat::Raw, + 1 => KernelFormat::Elf, + 2 => KernelFormat::PeGz, + 3 => KernelFormat::ImageBz2, + 4 => KernelFormat::ImageGz, + 5 => KernelFormat::ImageZstd, + _ => { + return -libc::EINVAL; + } + }; + + let (initramfs_path, initramfs_size) = if !c_initramfs_path.is_null() { + match CStr::from_ptr(c_initramfs_path).to_str() { + Ok(path) => { + let path = PathBuf::from(path); + let size = match std::fs::metadata(&path) { + Ok(metadata) => metadata.len(), + Err(e) => { + error!("Can't read initramfs metadata: {e:?}"); + return -libc::EINVAL; + } + }; + (Some(path), size) + } + Err(e) => { + error!("Error parsing initramfs path: {e:?}"); + return -libc::EINVAL; + } + } + } else { + (None, 0) + }; + + let cmdline = if !c_cmdline.is_null() { + match CStr::from_ptr(c_cmdline).to_str() { + Ok(cmdline) => Some(cmdline.to_string()), + Err(e) => { + error!("Error parsing kernel cmdline: {e:?}"); + return -libc::EINVAL; + } + } } else { - error!("Error retrieving max vcpu count: {ret:?}"); - -libc::EINVAL + None + }; + + let external_kernel = ExternalKernel { + path, + format, + initramfs_path, + initramfs_size, + cmdline, + }; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + ctx_cfg.get_mut().vmr.set_external_kernel(external_kernel) + } + Entry::Vacant(_) => return -libc::ENOENT, } + + KRUN_SUCCESS } - #[cfg(target_os = "linux")] - { - use kvm_ioctls::Kvm; - match Kvm::new() { - Ok(kvm) => kvm.get_max_vcpus() as i32, + #[cfg(not(feature = "tee"))] + #[allow(clippy::format_collect)] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_firmware(ctx_id: u32, c_firmware_path: *const c_char) -> i32 { + let path = match CStr::from_ptr(c_firmware_path).to_str() { + Ok(path) => PathBuf::from(path), Err(e) => { - error!("Error retrieving max vcpu count: {e:?}"); 
- -libc::EINVAL + error!("Error parsing firmware_path: {e:?}"); + return -libc::EINVAL; } - } - } -} + }; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub extern "C" fn krun_split_irqchip(ctx_id: u32, enable: bool) -> i32 { - if enable && !cfg!(target_arch = "x86_64") { - return -libc::EINVAL; - } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.vmr.split_irqchip = enable; - KRUN_SUCCESS + let firmware_config = FirmwareConfig { path }; + + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + ctx_cfg.get_mut().vmr.set_firmware_config(firmware_config) + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => -libc::ENOENT, - } -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_smbios_oem_strings( - ctx_id: u32, - oem_strings: *const *const c_char, -) -> i32 { - if oem_strings.is_null() { - return -libc::EINVAL; + KRUN_SUCCESS } - let cstr_ptr_slice = slice::from_raw_parts(oem_strings, MAX_ARGS); - - let mut oem_strings = Vec::new(); - - for cstr_ptr in cstr_ptr_slice.iter().take_while(|p| !p.is_null()) { - let Ok(s) = CStr::from_ptr(*cstr_ptr).to_str() else { - return -libc::EINVAL; + unsafe fn load_krunfw_payload( + krunfw: &KrunfwBindings, + vmr: &mut VmResources, + ) -> Result<(), libloading::Error> { + let mut kernel_guest_addr: u64 = 0; + let mut kernel_entry_addr: u64 = 0; + let mut kernel_size: usize = 0; + let kernel_host_addr = unsafe { + (krunfw.get_kernel)( + &mut kernel_guest_addr as *mut u64, + &mut kernel_entry_addr as *mut u64, + &mut kernel_size as *mut usize, + ) }; - oem_strings.push(s.to_string()); - } + let kernel_bundle = KernelBundle { + host_addr: kernel_host_addr as u64, + guest_addr: kernel_guest_addr, + entry_addr: kernel_entry_addr, + size: kernel_size, + }; + vmr.set_kernel_bundle(kernel_bundle).unwrap(); + + #[cfg(feature = "tee")] + { + let mut qboot_size: usize = 0; + let 
qboot_host_addr = unsafe { (krunfw.get_qboot)(&mut qboot_size as *mut usize) }; + let qboot_bundle = QbootBundle { + host_addr: qboot_host_addr as u64, + size: qboot_size, + }; + vmr.set_qboot_bundle(qboot_bundle).unwrap(); - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - ctx_cfg.get_mut().vmr.smbios_oem_strings = - (!oem_strings.is_empty()).then_some(oem_strings) + let mut initrd_size: usize = 0; + let initrd_host_addr = unsafe { (krunfw.get_initrd)(&mut initrd_size as *mut usize) }; + let initrd_bundle = InitrdBundle { + host_addr: initrd_host_addr as u64, + size: initrd_size, + }; + vmr.set_initrd_bundle(initrd_bundle).unwrap(); } - Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} + Ok(()) + } -#[cfg(feature = "net")] -fn create_virtio_net( - ctx_cfg: &mut ContextConfig, - backend: VirtioNetBackend, - mac: [u8; 6], - features: u32, -) { - let network_interface_config = NetworkInterfaceConfig { - iface_id: format!("eth{}", ctx_cfg.net_index), - backend, - mac, - features, - }; - ctx_cfg.net_index += 1; - ctx_cfg - .vmr - .add_network_interface(network_interface_config) - .expect("Failed to create network interface"); -} - -#[cfg(all(target_arch = "x86_64", not(feature = "tee")))] -fn map_kernel(ctx_id: u32, kernel_path: &PathBuf) -> i32 { - let file = match File::options().read(true).write(false).open(kernel_path) { - Ok(file) => file, - Err(err) => { - error!("Error opening external kernel: {err}"); - return -libc::EINVAL; + #[no_mangle] + pub extern "C" fn krun_setuid(ctx_id: u32, uid: libc::uid_t) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_vmm_uid(uid); + } + Entry::Vacant(_) => return -libc::ENOENT, } - }; - let kernel_size = file.metadata().unwrap().len(); - - let kernel_host_addr = unsafe { - libc::mmap( - std::ptr::null_mut(), - kernel_size as usize, - libc::PROT_READ, - libc::MAP_SHARED, - file.as_raw_fd(), - 
0_i64, - ) - }; - if std::ptr::eq(kernel_host_addr, libc::MAP_FAILED) { - error!("Can't load kernel into process map"); - return -libc::EINVAL; + KRUN_SUCCESS } - let kernel_bundle = KernelBundle { - host_addr: kernel_host_addr as u64, - guest_addr: 0x8000_0000, - entry_addr: 0x8000_0000, - size: kernel_size as usize, - }; - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => ctx_cfg - .get_mut() - .vmr - .set_kernel_bundle(kernel_bundle) - .unwrap(), - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[cfg(feature = "tee")] -#[allow(clippy::format_collect)] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_kernel(_ctx_id: u32, _c_kernel_path: *const c_char) -> i32 { - -libc::EOPNOTSUPP -} - -#[cfg(not(feature = "tee"))] -#[allow(clippy::format_collect)] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_kernel( - ctx_id: u32, - c_kernel_path: *const c_char, - kernel_format: u32, - c_initramfs_path: *const c_char, - c_cmdline: *const c_char, -) -> i32 { - let path = match CStr::from_ptr(c_kernel_path).to_str() { - Ok(path) => PathBuf::from(path), - Err(e) => { - error!("Error parsing kernel_path: {e:?}"); - return -libc::EINVAL; + #[no_mangle] + pub extern "C" fn krun_setgid(ctx_id: u32, gid: libc::gid_t) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.set_vmm_gid(gid); + } + Entry::Vacant(_) => return -libc::ENOENT, } - }; - let format = match kernel_format { - // For raw kernels in x86_64, we map the kernel into the - // process and treat it as a bundled kernel. 
- #[cfg(all(target_arch = "x86_64", not(feature = "tee")))] - 0 => return map_kernel(ctx_id, &path), - #[cfg(target_arch = "aarch64")] - 0 => KernelFormat::Raw, - 1 => KernelFormat::Elf, - 2 => KernelFormat::PeGz, - 3 => KernelFormat::ImageBz2, - 4 => KernelFormat::ImageGz, - 5 => KernelFormat::ImageZstd, - _ => { - return -libc::EINVAL; - } - }; + KRUN_SUCCESS + } - let (initramfs_path, initramfs_size) = if !c_initramfs_path.is_null() { - match CStr::from_ptr(c_initramfs_path).to_str() { - Ok(path) => { - let path = PathBuf::from(path); - let size = match std::fs::metadata(&path) { - Ok(metadata) => metadata.len(), - Err(e) => { - error!("Can't read initramfs metadata: {e:?}"); - return -libc::EINVAL; - } - }; - (Some(path), size) - } + #[cfg(all(feature = "blk", not(feature = "tee")))] + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_root_disk_remount( + ctx_id: u32, + c_device: *const c_char, + c_fstype: *const c_char, + c_options: *const c_char, + ) -> i32 { + let device = match CStr::from_ptr(c_device).to_str() { + Ok(device) => device.to_string(), Err(e) => { - error!("Error parsing initramfs path: {e:?}"); + error!("Error parsing device path: {e:?}"); return -libc::EINVAL; } - } - } else { - (None, 0) - }; + }; - let cmdline = if !c_cmdline.is_null() { - match CStr::from_ptr(c_cmdline).to_str() { - Ok(cmdline) => Some(cmdline.to_string()), - Err(e) => { - error!("Error parsing kernel cmdline: {e:?}"); - return -libc::EINVAL; + let fstype = if !c_fstype.is_null() { + match CStr::from_ptr(c_fstype).to_str() { + Ok(fstype) => { + if fstype == "auto" { + None + } else { + Some(fstype.to_string()) + } + } + Err(e) => { + error!("Error parsing fstype: {e:?}"); + return -libc::EINVAL; + } } - } - } else { - None - }; - - let external_kernel = ExternalKernel { - path, - format, - initramfs_path, - initramfs_size, - cmdline, - }; - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => 
ctx_cfg.get_mut().vmr.set_external_kernel(external_kernel), - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} + } else { + None + }; -#[cfg(not(feature = "tee"))] -#[allow(clippy::format_collect)] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_firmware(ctx_id: u32, c_firmware_path: *const c_char) -> i32 { - let path = match CStr::from_ptr(c_firmware_path).to_str() { - Ok(path) => PathBuf::from(path), - Err(e) => { - error!("Error parsing firmware_path: {e:?}"); - return -libc::EINVAL; - } - }; + let options = if !c_options.is_null() { + match CStr::from_ptr(c_options).to_str() { + Ok(options) => Some(options.to_string()), + Err(e) => { + error!("Error parsing options: {e:?}"); + return -libc::EINVAL; + } + } + } else { + None + }; - let firmware_config = FirmwareConfig { path }; - - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => ctx_cfg.get_mut().vmr.set_firmware_config(firmware_config), - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -unsafe fn load_krunfw_payload( - krunfw: &KrunfwBindings, - vmr: &mut VmResources, -) -> Result<(), libloading::Error> { - let mut kernel_guest_addr: u64 = 0; - let mut kernel_entry_addr: u64 = 0; - let mut kernel_size: usize = 0; - let kernel_host_addr = unsafe { - (krunfw.get_kernel)( - &mut kernel_guest_addr as *mut u64, - &mut kernel_entry_addr as *mut u64, - &mut kernel_size as *mut usize, - ) - }; - let kernel_bundle = KernelBundle { - host_addr: kernel_host_addr as u64, - guest_addr: kernel_guest_addr, - entry_addr: kernel_entry_addr, - size: kernel_size, - }; - vmr.set_kernel_bundle(kernel_bundle).unwrap(); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let ctx_cfg = ctx_cfg.get_mut(); - #[cfg(feature = "tee")] - { - let mut qboot_size: usize = 0; - let qboot_host_addr = unsafe { (krunfw.get_qboot)(&mut qboot_size as *mut usize) }; - let qboot_bundle = QbootBundle { - 
host_addr: qboot_host_addr as u64, - size: qboot_size, - }; - vmr.set_qboot_bundle(qboot_bundle).unwrap(); - - let mut initrd_size: usize = 0; - let initrd_host_addr = unsafe { (krunfw.get_initrd)(&mut initrd_size as *mut usize) }; - let initrd_bundle = InitrdBundle { - host_addr: initrd_host_addr as u64, - size: initrd_size, - }; - vmr.set_initrd_bundle(initrd_bundle).unwrap(); - } - - Ok(()) -} - -#[no_mangle] -pub extern "C" fn krun_setuid(ctx_id: u32, uid: libc::uid_t) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_vmm_uid(uid); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[no_mangle] -pub extern "C" fn krun_setgid(ctx_id: u32, gid: libc::gid_t) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.set_vmm_gid(gid); - } - Entry::Vacant(_) => return -libc::ENOENT, - } - - KRUN_SUCCESS -} - -#[cfg(all(feature = "blk", not(feature = "tee")))] -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_root_disk_remount( - ctx_id: u32, - c_device: *const c_char, - c_fstype: *const c_char, - c_options: *const c_char, -) -> i32 { - let device = match CStr::from_ptr(c_device).to_str() { - Ok(device) => device.to_string(), - Err(e) => { - error!("Error parsing device path: {e:?}"); - return -libc::EINVAL; - } - }; + if ctx_cfg.vmr.fs.iter().any(|fs| fs.fs_id == "/dev/root") { + error!("Root filesystem already configured"); + return -libc::EINVAL; + } - let fstype = if !c_fstype.is_null() { - match CStr::from_ptr(c_fstype).to_str() { - Ok(fstype) => { - if fstype == "auto" { - None - } else { - Some(fstype.to_string()) + if ctx_cfg.block_cfgs.is_empty() { + error!("No block devices configured"); + return -libc::EINVAL; } - } - Err(e) => { - error!("Error parsing fstype: {e:?}"); - return -libc::EINVAL; - } - } - } else { - None - }; - let options = if 
!c_options.is_null() { - match CStr::from_ptr(c_options).to_str() { - Ok(options) => Some(options.to_string()), - Err(e) => { - error!("Error parsing options: {e:?}"); - return -libc::EINVAL; - } - } - } else { - None - }; + // To boot from a filesystem other than virtiofs, + // we need to setup a temporary root from which init.krun can be executed. + // Otherwise, it would have to be copied to the target filesystem beforehand. + // Instead, init.krun will run from virtiofs and then switch to the real root. + let root_dir_suffix = Alphanumeric.sample_string(&mut rand::rng(), 6); + let empty_root = env::temp_dir().join(format!("krun-empty-root-{root_dir_suffix}")); - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let ctx_cfg = ctx_cfg.get_mut(); + if let Err(e) = std::fs::create_dir_all(&empty_root) { + error!("Failed to create empty root directory: {e:?}"); + return -libc::EINVAL; + } - if ctx_cfg.vmr.fs.iter().any(|fs| fs.fs_id == "/dev/root") { - error!("Root filesystem already configured"); - return -libc::EINVAL; - } + ctx_cfg.vmr.add_fs_device(FsDeviceConfig { + fs_id: "/dev/root".into(), + shared_dir: empty_root.to_string_lossy().into(), + // Default to a conservative 512 MB window. + shm_size: Some(1 << 29), + allow_root_dir_delete: true, + read_only: false, + }); - if ctx_cfg.block_cfgs.is_empty() { - error!("No block devices configured"); - return -libc::EINVAL; + ctx_cfg.set_block_root(device, fstype, options); } + Entry::Vacant(_) => return -libc::ENOENT, + }; - // To boot from a filesystem other than virtiofs, - // we need to setup a temporary root from which init.krun can be executed. - // Otherwise, it would have to be copied to the target filesystem beforehand. - // Instead, init.krun will run from virtiofs and then switch to the real root. 
- let root_dir_suffix = Alphanumeric.sample_string(&mut rand::rng(), 6); - let empty_root = env::temp_dir().join(format!("krun-empty-root-{root_dir_suffix}")); + KRUN_SUCCESS + } - if let Err(e) = std::fs::create_dir_all(&empty_root) { - error!("Failed to create empty root directory: {e:?}"); - return -libc::EINVAL; + #[no_mangle] + pub extern "C" fn krun_disable_implicit_console(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.disable_implicit_console = true; } - - ctx_cfg.vmr.add_fs_device(FsDeviceConfig { - fs_id: "/dev/root".into(), - shared_dir: empty_root.to_string_lossy().into(), - // Default to a conservative 512 MB window. - shm_size: Some(1 << 29), - allow_root_dir_delete: true, - read_only: false, - }); - - ctx_cfg.set_block_root(device, fstype, options); + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, - }; - KRUN_SUCCESS -} + KRUN_SUCCESS + } -#[no_mangle] -pub extern "C" fn krun_disable_implicit_console(ctx_id: u32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.vmr.disable_implicit_console = true; + #[no_mangle] + pub extern "C" fn krun_disable_implicit_vsock(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vsock_config = VsockConfig::Disabled; + } + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + + KRUN_SUCCESS } - KRUN_SUCCESS -} + #[no_mangle] + pub extern "C" fn krun_add_vsock(ctx_id: u32, tsi_features: u32) -> i32 { + let tsi_flags = match TsiFlags::from_bits(tsi_features) { + Some(flags) => flags, + None => return -libc::EINVAL, + }; -#[no_mangle] -pub extern "C" fn krun_disable_implicit_vsock(ctx_id: u32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) 
=> { - let cfg = ctx_cfg.get_mut(); - cfg.vsock_config = VsockConfig::Disabled; + if cfg!(target_os = "macos") && tsi_flags.contains(TsiFlags::HIJACK_UNIX) { + error!("TSI hijacking of UNIX sockets is not yet supported on macOS"); + return -libc::EINVAL; } - Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} - -#[no_mangle] -pub extern "C" fn krun_add_vsock(ctx_id: u32, tsi_features: u32) -> i32 { - let tsi_flags = match TsiFlags::from_bits(tsi_features) { - Some(flags) => flags, - None => return -libc::EINVAL, - }; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + if cfg.vsock_config != VsockConfig::Disabled { + return -libc::EEXIST; + } + cfg.vsock_config = VsockConfig::Explicit { tsi_flags }; + } + Entry::Vacant(_) => return -libc::ENOENT, + } - if cfg!(target_os = "macos") && tsi_flags.contains(TsiFlags::HIJACK_UNIX) { - error!("TSI hijacking of UNIX sockets is not yet supported on macOS"); - return -libc::EINVAL; + KRUN_SUCCESS } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - if cfg.vsock_config != VsockConfig::Disabled { - return -libc::EEXIST; + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_virtio_console_default( + ctx_id: u32, + input_fd: libc::c_int, + output_fd: libc::c_int, + err_fd: libc::c_int, + ) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + + cfg.vmr + .virtio_consoles + .push(VirtioConsoleConfigMode::Autoconfigure( + DefaultVirtioConsoleConfig { + input_fd, + output_fd, + err_fd, + }, + )); } - cfg.vsock_config = VsockConfig::Explicit { tsi_flags }; + Entry::Vacant(_) => return -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, + + KRUN_SUCCESS } - KRUN_SUCCESS -} + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn 
krun_add_virtio_console_multiport(ctx_id: u32) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + let console_id = cfg.vmr.virtio_consoles.len() as i32; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_virtio_console_default( - ctx_id: u32, - input_fd: libc::c_int, - output_fd: libc::c_int, - err_fd: libc::c_int, -) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); + cfg.vmr + .virtio_consoles + .push(VirtioConsoleConfigMode::Explicit(Vec::new())); - cfg.vmr - .virtio_consoles - .push(VirtioConsoleConfigMode::Autoconfigure( - DefaultVirtioConsoleConfig { - input_fd, - output_fd, - err_fd, - }, - )); + console_id + } + Entry::Vacant(_) => -libc::ENOENT, } - Entry::Vacant(_) => return -libc::ENOENT, } - KRUN_SUCCESS -} + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_console_port_tty( + ctx_id: u32, + console_id: u32, + name: *const libc::c_char, + tty_fd: libc::c_int, + ) -> i32 { + if tty_fd < 0 { + return -libc::EINVAL; + } + + let name_str = if name.is_null() { + String::new() + } else { + match CStr::from_ptr(name).to_str() { + Ok(s) => s.to_string(), + Err(_) => return -libc::EINVAL, + } + }; -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_virtio_console_multiport(ctx_id: u32) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - let console_id = cfg.vmr.virtio_consoles.len() as i32; + if !BorrowedFd::borrow_raw(tty_fd).is_terminal() { + return -libc::ENOTTY; + } - cfg.vmr - .virtio_consoles - .push(VirtioConsoleConfigMode::Explicit(Vec::new())); + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); - console_id + match cfg.vmr.virtio_consoles.get_mut(console_id as 
usize) { + Some(VirtioConsoleConfigMode::Explicit(ports)) => { + ports.push(PortConfig::Tty { + name: name_str, + tty_fd, + }); + KRUN_SUCCESS + } + _ => -libc::EINVAL, + } + } + Entry::Vacant(_) => -libc::ENOENT, } - Entry::Vacant(_) => -libc::ENOENT, } -} -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_console_port_tty( - ctx_id: u32, - console_id: u32, - name: *const libc::c_char, - tty_fd: libc::c_int, -) -> i32 { - if tty_fd < 0 { - return -libc::EINVAL; - } + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_console_port_inout( + ctx_id: u32, + console_id: u32, + name: *const c_char, + input_fd: c_int, + output_fd: c_int, + ) -> i32 { + let name_str = if name.is_null() { + String::new() + } else { + match CStr::from_ptr(name).to_str() { + Ok(s) => s.to_string(), + Err(_) => return -libc::EINVAL, + } + }; - let name_str = if name.is_null() { - String::new() - } else { - match CStr::from_ptr(name).to_str() { - Ok(s) => s.to_string(), - Err(_) => return -libc::EINVAL, + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + + match cfg.vmr.virtio_consoles.get_mut(console_id as usize) { + Some(VirtioConsoleConfigMode::Explicit(ports)) => { + ports.push(PortConfig::InOut { + name: name_str, + input_fd, + output_fd, + }); + KRUN_SUCCESS + } + _ => -libc::EINVAL, + } + } + Entry::Vacant(_) => -libc::ENOENT, + } + } + + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_add_serial_console_default( + ctx_id: u32, + input_fd: c_int, + output_fd: c_int, + ) -> i32 { + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.serial_consoles.push(SerialConsoleConfig { + input_fd, + output_fd, + }); + } + Entry::Vacant(_) => return -libc::ENOENT, } - }; - if !BorrowedFd::borrow_raw(tty_fd).is_terminal() { - return -libc::ENOTTY; + 
KRUN_SUCCESS } - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - - match cfg.vmr.virtio_consoles.get_mut(console_id as usize) { - Some(VirtioConsoleConfigMode::Explicit(ports)) => { - ports.push(PortConfig::Tty { - name: name_str, - tty_fd, - }); - KRUN_SUCCESS - } - _ => -libc::EINVAL, - } - } - Entry::Vacant(_) => -libc::ENOENT, - } -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_console_port_inout( - ctx_id: u32, - console_id: u32, - name: *const c_char, - input_fd: c_int, - output_fd: c_int, -) -> i32 { - let name_str = if name.is_null() { - String::new() - } else { - match CStr::from_ptr(name).to_str() { - Ok(s) => s.to_string(), + #[allow(clippy::missing_safety_doc)] + #[no_mangle] + pub unsafe extern "C" fn krun_set_kernel_console( + ctx_id: u32, + console_id: *const c_char, + ) -> i32 { + let console_id = match CStr::from_ptr(console_id).to_str() { + Ok(id) => id.to_string(), Err(_) => return -libc::EINVAL, + }; + match CTX_MAP.lock().unwrap().entry(ctx_id) { + Entry::Occupied(mut ctx_cfg) => { + let cfg = ctx_cfg.get_mut(); + cfg.vmr.kernel_console = Some(console_id); + } + Entry::Vacant(_) => return -libc::ENOENT, } - }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - - match cfg.vmr.virtio_consoles.get_mut(console_id as usize) { - Some(VirtioConsoleConfigMode::Explicit(ports)) => { - ports.push(PortConfig::InOut { - name: name_str, - input_fd, - output_fd, - }); - KRUN_SUCCESS - } - _ => -libc::EINVAL, - } - } - Entry::Vacant(_) => -libc::ENOENT, - } -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_add_serial_console_default( - ctx_id: u32, - input_fd: c_int, - output_fd: c_int, -) -> i32 { - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - 
cfg.vmr.serial_consoles.push(SerialConsoleConfig { - input_fd, - output_fd, - }); - } - Entry::Vacant(_) => return -libc::ENOENT, + KRUN_SUCCESS } - KRUN_SUCCESS -} - -#[allow(clippy::missing_safety_doc)] -#[no_mangle] -pub unsafe extern "C" fn krun_set_kernel_console(ctx_id: u32, console_id: *const c_char) -> i32 { - let console_id = match CStr::from_ptr(console_id).to_str() { - Ok(id) => id.to_string(), - Err(_) => return -libc::EINVAL, - }; - match CTX_MAP.lock().unwrap().entry(ctx_id) { - Entry::Occupied(mut ctx_cfg) => { - let cfg = ctx_cfg.get_mut(); - cfg.vmr.kernel_console = Some(console_id); + #[no_mangle] + #[allow(unreachable_code)] + pub extern "C" fn krun_start_enter(ctx_id: u32) -> i32 { + #[cfg(target_os = "linux")] + { + let prname = match env::var("HOSTNAME") { + Ok(val) => CString::new(format!("VM:{val}")).unwrap(), + Err(_) => CString::new("libkrun VM").unwrap(), + }; + unsafe { libc::prctl(libc::PR_SET_NAME, prname.as_ptr()) }; } - Entry::Vacant(_) => return -libc::ENOENT, - } - KRUN_SUCCESS -} + #[cfg(feature = "aws-nitro")] + return krun_start_enter_nitro(ctx_id); -#[no_mangle] -#[allow(unreachable_code)] -pub extern "C" fn krun_start_enter(ctx_id: u32) -> i32 { - #[cfg(target_os = "linux")] - { - let prname = match env::var("HOSTNAME") { - Ok(val) => CString::new(format!("VM:{val}")).unwrap(), - Err(_) => CString::new("libkrun VM").unwrap(), + let mut event_manager = match EventManager::new() { + Ok(em) => em, + Err(e) => { + error!("Unable to create EventManager: {e:?}"); + return -libc::EINVAL; + } }; - unsafe { libc::prctl(libc::PR_SET_NAME, prname.as_ptr()) }; - } - - #[cfg(feature = "aws-nitro")] - return krun_start_enter_nitro(ctx_id); - - let mut event_manager = match EventManager::new() { - Ok(em) => em, - Err(e) => { - error!("Unable to create EventManager: {e:?}"); - return -libc::EINVAL; - } - }; - let mut ctx_cfg = match CTX_MAP.lock().unwrap().remove(&ctx_id) { - Some(ctx_cfg) => ctx_cfg, - None => return -libc::ENOENT, - }; + 
let mut ctx_cfg = match CTX_MAP.lock().unwrap().remove(&ctx_id) { + Some(ctx_cfg) => ctx_cfg, + None => return -libc::ENOENT, + }; - if ctx_cfg.vmr.external_kernel.is_none() - && ctx_cfg.vmr.kernel_bundle.is_none() - && ctx_cfg.vmr.firmware_config.is_none() - { - if let Some(ref krunfw) = ctx_cfg.krunfw { - if let Err(err) = unsafe { load_krunfw_payload(krunfw, &mut ctx_cfg.vmr) } { - eprintln!("Can't load libkrunfw symbols: {err}"); + if ctx_cfg.vmr.external_kernel.is_none() + && ctx_cfg.vmr.kernel_bundle.is_none() + && ctx_cfg.vmr.firmware_config.is_none() + { + if let Some(ref krunfw) = ctx_cfg.krunfw { + if let Err(err) = unsafe { load_krunfw_payload(krunfw, &mut ctx_cfg.vmr) } { + eprintln!("Can't load libkrunfw symbols: {err}"); + return -libc::ENOENT; + } + } else { + eprintln!("Couldn't find or load {KRUNFW_NAME}"); return -libc::ENOENT; } - } else { - eprintln!("Couldn't find or load {KRUNFW_NAME}"); - return -libc::ENOENT; } - } - #[cfg(feature = "blk")] - for block_cfg in ctx_cfg.get_block_cfg() { - if ctx_cfg.vmr.add_block_device(block_cfg).is_err() { - error!("Error configuring virtio-blk for block"); - return -libc::EINVAL; + #[cfg(feature = "blk")] + for block_cfg in ctx_cfg.get_block_cfg() { + if ctx_cfg.vmr.add_block_device(block_cfg).is_err() { + error!("Error configuring virtio-blk for block"); + return -libc::EINVAL; + } } - } - /* - * Before krun_start_enter() is called in an encrypted context, the TEE - * config must have been set via krun_set_tee_config_file(). If the TEE - * config is not set by this point, print the relevant error message and - * fail. - */ - #[cfg(feature = "tee")] - if let Some(tee_config) = ctx_cfg.get_tee_config_file() { - if let Err(e) = ctx_cfg.vmr.set_tee_config(tee_config) { - error!("Error setting up TEE config: {e:?}"); + /* + * Before krun_start_enter() is called in an encrypted context, the TEE + * config must have been set via krun_set_tee_config_file(). 
If the TEE + * config is not set by this point, print the relevant error message and + * fail. + */ + #[cfg(feature = "tee")] + if let Some(tee_config) = ctx_cfg.get_tee_config_file() { + if let Err(e) = ctx_cfg.vmr.set_tee_config(tee_config) { + error!("Error setting up TEE config: {e:?}"); + return -libc::EINVAL; + } + } else { + error!("Missing TEE config file"); return -libc::EINVAL; } - } else { - error!("Missing TEE config file"); - return -libc::EINVAL; - } - - let kernel_cmdline = KernelCmdlineConfig { - prolog: Some(format!("{DEFAULT_KERNEL_CMDLINE} init={INIT_PATH}")), - krun_env: Some(format!( - " {} {} {} {} {}", - ctx_cfg.get_exec_path(), - ctx_cfg.get_workdir(), - ctx_cfg.get_block_root(), - ctx_cfg.get_rlimits(), - ctx_cfg.get_env(), - )), - epilog: Some(format!(" -- {}", ctx_cfg.get_args())), - }; - if ctx_cfg.vmr.set_kernel_cmdline(kernel_cmdline).is_err() { - return -libc::EINVAL; - } + let kernel_cmdline = KernelCmdlineConfig { + prolog: Some(format!("{DEFAULT_KERNEL_CMDLINE} init={INIT_PATH}")), + krun_env: Some(format!( + " {} {} {} {} {}", + ctx_cfg.get_exec_path(), + ctx_cfg.get_workdir(), + ctx_cfg.get_block_root(), + ctx_cfg.get_rlimits(), + ctx_cfg.get_env(), + )), + epilog: Some(format!(" -- {}", ctx_cfg.get_args())), + }; - #[cfg(feature = "net")] - { - if let Some(legacy_net_cfg) = ctx_cfg.legacy_net_cfg.clone() { - let backend = match legacy_net_cfg { - LegacyNetworkConfig::VirtioNetGvproxy(path) => { - VirtioNetBackend::UnixgramPath(path, true) - } - LegacyNetworkConfig::VirtioNetPasst(fd) => VirtioNetBackend::UnixstreamFd(fd), - }; - let mac = ctx_cfg - .legacy_mac - .unwrap_or([0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee]); - create_virtio_net(&mut ctx_cfg, backend, mac, NET_COMPAT_FEATURES); - } - } - - match &ctx_cfg.vsock_config { - VsockConfig::Disabled => (), - VsockConfig::Explicit { tsi_flags } => { - let vsock_device_config = VsockDeviceConfig { - vsock_id: "vsock0".to_string(), - guest_cid: 3, - host_port_map: ctx_cfg.tsi_port_map, 
- unix_ipc_port_map: ctx_cfg.unix_ipc_port_map.clone(), - tsi_flags: *tsi_flags, - }; - ctx_cfg.vmr.set_vsock_device(vsock_device_config).unwrap(); + if ctx_cfg.vmr.set_kernel_cmdline(kernel_cmdline).is_err() { + return -libc::EINVAL; } - VsockConfig::Implicit => { - // Implicit vsock configuration - use heuristics - // Check if TSI should be enabled based on network configuration - #[cfg(feature = "net")] - let enable_tsi = ctx_cfg.vmr.net.list.is_empty() && ctx_cfg.legacy_net_cfg.is_none(); - #[cfg(not(feature = "net"))] - let enable_tsi = true; - let has_ipc_map = ctx_cfg.unix_ipc_port_map.is_some(); - - if enable_tsi || has_ipc_map { - let (tsi_flags, host_port_map) = if enable_tsi { - (TsiFlags::HIJACK_INET, ctx_cfg.tsi_port_map) - } else { - (TsiFlags::empty(), None) + #[cfg(feature = "net")] + { + if let Some(legacy_net_cfg) = ctx_cfg.legacy_net_cfg.clone() { + let backend = match legacy_net_cfg { + LegacyNetworkConfig::VirtioNetGvproxy(path) => { + VirtioNetBackend::UnixgramPath(path, true) + } + LegacyNetworkConfig::VirtioNetPasst(fd) => VirtioNetBackend::UnixstreamFd(fd), }; + let mac = ctx_cfg + .legacy_mac + .unwrap_or([0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee]); + create_virtio_net(&mut ctx_cfg, backend, mac, NET_COMPAT_FEATURES); + } + } + match &ctx_cfg.vsock_config { + VsockConfig::Disabled => (), + VsockConfig::Explicit { tsi_flags } => { let vsock_device_config = VsockDeviceConfig { vsock_id: "vsock0".to_string(), guest_cid: 3, - host_port_map, + host_port_map: ctx_cfg.tsi_port_map, unix_ipc_port_map: ctx_cfg.unix_ipc_port_map.clone(), - tsi_flags, + tsi_flags: *tsi_flags, }; ctx_cfg.vmr.set_vsock_device(vsock_device_config).unwrap(); } + VsockConfig::Implicit => { + // Implicit vsock configuration - use heuristics + // Check if TSI should be enabled based on network configuration + #[cfg(feature = "net")] + let enable_tsi = + ctx_cfg.vmr.net.list.is_empty() && ctx_cfg.legacy_net_cfg.is_none(); + #[cfg(not(feature = "net"))] + let enable_tsi = true; + + 
let has_ipc_map = ctx_cfg.unix_ipc_port_map.is_some(); + + if enable_tsi || has_ipc_map { + let (tsi_flags, host_port_map) = if enable_tsi { + (TsiFlags::HIJACK_INET, ctx_cfg.tsi_port_map) + } else { + (TsiFlags::empty(), None) + }; + + let vsock_device_config = VsockDeviceConfig { + vsock_id: "vsock0".to_string(), + guest_cid: 3, + host_port_map, + unix_ipc_port_map: ctx_cfg.unix_ipc_port_map.clone(), + tsi_flags, + }; + ctx_cfg.vmr.set_vsock_device(vsock_device_config).unwrap(); + } + } } - } - if let Some(virgl_flags) = ctx_cfg.gpu_virgl_flags { - ctx_cfg.vmr.set_gpu_virgl_flags(virgl_flags); - } - if let Some(shm_size) = ctx_cfg.gpu_shm_size { - ctx_cfg.vmr.set_gpu_shm_size(shm_size); - } + if let Some(virgl_flags) = ctx_cfg.gpu_virgl_flags { + ctx_cfg.vmr.set_gpu_virgl_flags(virgl_flags); + } + if let Some(shm_size) = ctx_cfg.gpu_shm_size { + ctx_cfg.vmr.set_gpu_shm_size(shm_size); + } - #[cfg(feature = "snd")] - ctx_cfg.vmr.set_snd_device(ctx_cfg.enable_snd); + #[cfg(feature = "snd")] + ctx_cfg.vmr.set_snd_device(ctx_cfg.enable_snd); - if let Some(console_output) = ctx_cfg.console_output { - ctx_cfg.vmr.set_console_output(console_output); - } + if let Some(console_output) = ctx_cfg.console_output { + ctx_cfg.vmr.set_console_output(console_output); + } - if let Some(gid) = ctx_cfg.vmm_gid { - if unsafe { libc::setgid(gid) } != 0 { - error!("Failed to set gid {gid}"); - return -std::io::Error::last_os_error().raw_os_error().unwrap(); + if let Some(gid) = ctx_cfg.vmm_gid { + if unsafe { libc::setgid(gid) } != 0 { + error!("Failed to set gid {gid}"); + return -std::io::Error::last_os_error().raw_os_error().unwrap(); + } } - } - if let Some(uid) = ctx_cfg.vmm_uid { - if unsafe { libc::setuid(uid) } != 0 { - error!("Failed to set uid {uid}"); - return -std::io::Error::last_os_error().raw_os_error().unwrap(); + if let Some(uid) = ctx_cfg.vmm_uid { + if unsafe { libc::setuid(uid) } != 0 { + error!("Failed to set uid {uid}"); + return 
/// Nitro-enclave variant of `krun_start_enter`.
///
/// Removes the context `ctx_id` from `CTX_MAP`, converts it into a
/// `NitroEnclave` and runs it.
///
/// Returns the value produced by `NitroEnclave::run` on success, or a
/// negated errno:
/// - `-ENOENT` if no context is registered under `ctx_id`;
/// - `-EINVAL` if the context cannot be converted into an enclave, or if
///   running the enclave fails (the run error is logged).
#[cfg(feature = "aws-nitro")]
#[no_mangle]
fn krun_start_enter_nitro(ctx_id: u32) -> i32 {
    let Some(ctx_cfg) = CTX_MAP.lock().unwrap().remove(&ctx_id) else {
        return -libc::ENOENT;
    };

    let enclave = match NitroEnclave::try_from(ctx_cfg) {
        Ok(enclave) => enclave,
        Err(_) => return -libc::EINVAL,
    };

    enclave.run().unwrap_or_else(|e| {
        error!("Error running nitro enclave: {e}");

        -libc::EINVAL
    })
}
/// Convert a C string to a Rust `PathBuf`. Returns `None` for null pointers.
///
/// The bytes are decoded with `CStr::to_string_lossy`, so any invalid UTF-8
/// sequence is replaced with U+FFFD rather than causing an error.
///
/// # Safety
///
/// The pointer must be null or point to a valid null-terminated C string.
unsafe fn c_str_to_path(ptr: *const c_char) -> Option<PathBuf> {
    if ptr.is_null() {
        return None;
    }
    // Non-null was checked above; validity of the string is the caller's contract.
    let text = CStr::from_ptr(ptr).to_string_lossy().into_owned();
    Some(PathBuf::from(text))
}
/// Convert a null-terminated array of C strings to a `Vec<String>`.
///
/// A null `arr` yields an empty vector. Each element is decoded with
/// `CStr::to_string_lossy`, so invalid UTF-8 is replaced with U+FFFD.
/// The walk stops at the first null element; no length limit is applied
/// here, so callers that need one must check the returned length.
///
/// # Safety
///
/// `arr` must be null or point to a null-terminated array of null-terminated C strings.
unsafe fn c_str_array_to_vec(arr: *const *const c_char) -> Vec<String> {
    let mut result = Vec::new();
    if arr.is_null() {
        return result;
    }

    let mut cursor = arr;
    // The caller guarantees a terminating null entry, which ends this loop.
    while !(*cursor).is_null() {
        result.push(CStr::from_ptr(*cursor).to_string_lossy().into_owned());
        cursor = cursor.add(1);
    }
    result
}
log::error!("krun_create_ctx: {}", e); + -1 + } + } +} + +#[no_mangle] +pub extern "C" fn krun_free_ctx(ctx_id: u32) -> i32 { + to_c_result(context::free_ctx(ctx_id)) +} + +// ============================================================================ +// VM configuration +// ============================================================================ + +#[no_mangle] +pub extern "C" fn krun_set_vm_config(ctx_id: u32, num_vcpus: u8, ram_mib: u32) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + if ctx.state != VmState::Created { + return Err(WkrunError::InvalidState { + expected: "Created", + actual: ctx.state.to_string(), + }); + } + if num_vcpus == 0 { + return Err(WkrunError::Config("num_vcpus must be > 0".into())); + } + if ram_mib == 0 { + return Err(WkrunError::Config("ram_mib must be > 0".into())); + } + ctx.num_vcpus = num_vcpus; + ctx.ram_mib = ram_mib; + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_root(ctx_id: u32, c_root_path: *const c_char) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + ctx.root_path = c_str_to_path(c_root_path); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_virtiofs( + ctx_id: u32, + c_tag: *const c_char, + c_path: *const c_char, +) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + let tag = c_str_to_string(c_tag) + .ok_or_else(|| WkrunError::Config("virtiofs tag cannot be null".into()))?; + let path = c_str_to_path(c_path) + .ok_or_else(|| WkrunError::Config("virtiofs path cannot be null".into()))?; + ctx.fs_mounts.push(FsMount { + tag, + host_path: path, + }); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_virtiofs2( + ctx_id: u32, + c_tag: *const c_char, + c_path: *const c_char, + _port: u32, +) -> i32 { + // On Windows, virtiofs2 is treated the same as virtiofs (no port parameter needed). 
+ krun_add_virtiofs(ctx_id, c_tag, c_path) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_disk2( + ctx_id: u32, + c_block_id: *const c_char, + c_disk_path: *const c_char, + disk_format: u32, + read_only: bool, +) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + let id = c_str_to_string(c_block_id) + .ok_or_else(|| WkrunError::Config("block_id cannot be null".into()))?; + let path = c_str_to_path(c_disk_path) + .ok_or_else(|| WkrunError::Config("disk_path cannot be null".into()))?; + if disk_format != DISK_FORMAT_RAW && disk_format != DISK_FORMAT_QCOW2 { + return Err(WkrunError::Config(format!( + "unsupported disk format: {}", + disk_format + ))); + } + ctx.disks.push(DiskConfig { + block_id: id, + path, + format: disk_format, + read_only, + }); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_disk( + ctx_id: u32, + c_block_id: *const c_char, + c_disk_path: *const c_char, + read_only: bool, +) -> i32 { + krun_add_disk2(ctx_id, c_block_id, c_disk_path, DISK_FORMAT_RAW, read_only) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_vsock_port2( + ctx_id: u32, + port: u32, + c_filepath: *const c_char, + listen: bool, +) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + let path = c_str_to_path(c_filepath) + .ok_or_else(|| WkrunError::Config("vsock filepath cannot be null".into()))?; + ctx.vsock_ports.push(VsockPort { + port, + host_path: path, + listen, + host_tcp_port: None, + }); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_exec( + ctx_id: u32, + c_exec_path: *const c_char, + argv: *const *const c_char, + envp: *const *const c_char, +) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + ctx.exec_path = c_str_to_string(c_exec_path); + let args = c_str_array_to_vec(argv); + if args.len() > MAX_ARGS { + return Err(WkrunError::Config(format!( + "too many arguments: {} > {}", + args.len(), + MAX_ARGS + ))); + } + ctx.argv = args; + let env = c_str_array_to_vec(envp); + if env.len() 
> MAX_ARGS { + return Err(WkrunError::Config(format!( + "too many env vars: {} > {}", + env.len(), + MAX_ARGS + ))); + } + ctx.envp = env; + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_env(ctx_id: u32, c_envp: *const *const c_char) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + let env = c_str_array_to_vec(c_envp); + if env.len() > MAX_ARGS { + return Err(WkrunError::Config(format!( + "too many env vars: {} > {}", + env.len(), + MAX_ARGS + ))); + } + ctx.envp = env; + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_workdir(ctx_id: u32, c_workdir_path: *const c_char) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + ctx.workdir = c_str_to_string(c_workdir_path); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_rlimits(ctx_id: u32, c_rlimits: *const *const c_char) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + ctx.rlimits = c_str_array_to_vec(c_rlimits); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_console_output(ctx_id: u32, c_filepath: *const c_char) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + ctx.console_output = c_str_to_path(c_filepath); + Ok(()) + })) +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_kernel( + ctx_id: u32, + c_kernel_path: *const c_char, + _format: u32, + c_initramfs: *const c_char, + c_cmdline: *const c_char, +) -> i32 { + to_c_result(context::with_ctx_mut(ctx_id, |ctx| { + ctx.kernel_path = c_str_to_path(c_kernel_path); + ctx.initramfs_path = c_str_to_path(c_initramfs); + ctx.kernel_cmdline = c_str_to_string(c_cmdline); + Ok(()) + })) +} + +// ============================================================================ +// Networking +// ============================================================================ + +/// Add a network device backed by a TCP endpoint. +/// +/// On Windows, networking uses TCP sockets to a userspace network proxy +/// (e.g., gvproxy). 
This replaces the Unix-specific `krun_add_net_unixstream`
+/// and `krun_add_net_unixgram`.
+///
+/// Stores a `NetConfig` in the context's `net_config` slot, replacing whatever
+/// was there. A null `c_endpoint` is rejected with a configuration error. When
+/// `c_mac` is null the MAC is produced by `generate_mac(ctx_id)`; otherwise six
+/// bytes are read from `c_mac`.
+///
+/// # Safety
+///
+/// `c_endpoint` is passed to `c_str_to_path`, and `c_mac`, when non-null, must
+/// be readable for at least six bytes.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_net(
+    ctx_id: u32,
+    c_endpoint: *const c_char,
+    c_mac: *const u8,
+) -> i32 {
+    let outcome = context::with_ctx_mut(ctx_id, |ctx| {
+        // The endpoint is mandatory; bail out before touching the context.
+        let Some(socket_path) = c_str_to_path(c_endpoint) else {
+            return Err(WkrunError::Config("net endpoint cannot be null".into()));
+        };
+
+        let mac = if !c_mac.is_null() {
+            // Caller-supplied MAC: take its six bytes by value.
+            std::ptr::read(c_mac.cast::<[u8; 6]>())
+        } else {
+            vmm::windows::devices::virtio::net::generate_mac(ctx_id)
+        };
+
+        ctx.net_config = Some(NetConfig { mac, socket_path });
+        Ok(())
+    });
+    to_c_result(outcome)
+}
+
+/// Unix stream networking — not available on Windows.
+///
+/// Logs a warning pointing at `krun_add_net` and returns `-libc::ENOSYS`.
+/// None of the arguments are read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_net_unixstream(
+    _ctx_id: u32,
+    _c_path: *const c_char,
+    _fd: i32,
+    _c_mac: *const u8,
+    _features: u32,
+    _flags: u32,
+) -> i32 {
+    log::warn!("krun_add_net_unixstream: not available on Windows, use krun_add_net");
+    -libc::ENOSYS
+}
+
+/// Unix dgram networking — not available on Windows.
+///
+/// Logs a warning pointing at `krun_add_net` and returns `-libc::ENOSYS`.
+/// None of the arguments are read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_net_unixgram(
+    _ctx_id: u32,
+    _c_path: *const c_char,
+    _fd: i32,
+    _c_mac: *const u8,
+    _features: u32,
+    _flags: u32,
+) -> i32 {
+    log::warn!("krun_add_net_unixgram: not available on Windows, use krun_add_net");
+    -libc::ENOSYS
+}
+
+// ============================================================================
+// No-ops on Windows
+// ============================================================================
+
+/// Accepts and ignores a UID; always returns 0.
+#[no_mangle]
+pub extern "C" fn krun_setuid(_ctx_id: u32, _uid: u32) -> i32 {
+    0 // No-op on Windows
+}
+
+/// Accepts and ignores a GID; always returns 0.
+#[no_mangle]
+pub extern "C" fn krun_setgid(_ctx_id: u32, _gid: u32) -> i32 {
+    0 // No-op on Windows
+}
+
+/// Accepts and ignores GPU flags; always returns 0.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_gpu_options(_ctx_id: u32, _virgl_flags: u32) -> i32 {
+    0 // No-op
+}
+
+/// Accepts and ignores the split-irqchip toggle; always returns 0.
+#[no_mangle]
+pub extern "C" fn krun_split_irqchip(_ctx_id: u32, _enable: bool) -> i32 {
+    0 // No-op on Windows
+}
+
+/// Always returns 0 without touching the context.
+#[no_mangle]
+pub unsafe extern "C" fn krun_disable_tsi(_ctx_id: u32) -> i32 {
+    0 // No-op on Windows (no TSI)
+}
+
+/// Accepts and ignores the nested-virtualization toggle; always returns 0.
+///
+/// Note that a request to *enable* nested virtualization also reports success
+/// even though nothing is configured.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_nested_virt(_ctx_id: u32, _enabled: bool) -> i32 {
+    0 // No-op on Windows
+}
+
+/// Always returns 0.
+///
+/// NOTE(review): presumably 0 means "nested virtualization unavailable" to
+/// callers — confirm against the C header's documented return values.
+#[no_mangle]
+pub unsafe extern "C" fn krun_check_nested_virt() -> i32 {
+    0 // Not supported on Windows
+}
+
+/// Returns the fixed vCPU limit reported to callers (64).
+#[no_mangle]
+pub extern "C" fn krun_get_max_vcpus() -> i32 {
+    // WHPX supports up to 64 vCPUs, but we cap at a reasonable default.
+    // NOTE(review): the value returned equals the maximum quoted above, so no
+    // cap is actually applied — confirm whether a lower default was intended.
+    64
+}
+
+/// Always returns `-libc::ENOSYS`; no shutdown descriptor is provided.
+#[no_mangle]
+pub extern "C" fn krun_get_shutdown_eventfd(_ctx_id: u32) -> i32 {
+    -libc::ENOSYS // eventfd not available on Windows
+}
+
+/// Always returns 0 without touching the context.
+#[no_mangle]
+pub extern "C" fn krun_disable_implicit_console(_ctx_id: u32) -> i32 {
+    0 // No-op
+}
+
+// Stubs for functions that reference Unix-only features.
+/// Always returns `-libc::ENOSYS`; the path pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_root_disk(_ctx_id: u32, _c_disk_path: *const c_char) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`; the path pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_data_disk(_ctx_id: u32, _c_disk_path: *const c_char) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Records the root-disk device and filesystem type on the context.
+///
+/// `device` and `fstype` are converted with `c_str_to_string` and stored in
+/// `root_disk_device` / `root_disk_fstype`. `_options` is accepted but never
+/// read, so mount options supplied by the caller are dropped silently.
+///
+/// # Safety
+///
+/// `device` and `fstype` are passed to `c_str_to_string`.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_root_disk_remount(
+    ctx_id: u32,
+    device: *const c_char,
+    fstype: *const c_char,
+    _options: *const c_char,
+) -> i32 {
+    to_c_result(context::with_ctx_mut(ctx_id, |ctx| {
+        ctx.root_disk_device = c_str_to_string(device);
+        ctx.root_disk_fstype = c_str_to_string(fstype);
+        Ok(())
+    }))
+}
+
+/// Always returns `-libc::ENOSYS`; the array pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_mapped_volumes(
+    _ctx_id: u32,
+    _c_mapped_volumes: *const *const c_char,
+) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Accepts and ignores a port map; always returns 0.
+///
+/// Unlike the `-ENOSYS` stubs around it, this reports success even though the
+/// requested mappings are not applied.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_port_map(_ctx_id: u32, _c_port_map: *const *const c_char) -> i32 {
+    0 // No-op
+}
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_passt_fd(_ctx_id: u32, _fd: i32) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`; the path pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_gvproxy_path(_ctx_id: u32, _c_path: *const c_char) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Accepts and ignores a MAC address; always returns 0.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_net_mac(_ctx_id: u32, _c_mac: *const u8) -> i32 {
+    0 // No-op, use krun_add_net
+}
+
+/// Accepts and ignores the sound-device toggle; always returns 0.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_snd_device(_ctx_id: u32, _enable: bool) -> i32 {
+    0 // No-op
+}
+
+/// Always returns `-libc::ENOSYS`; the path pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_firmware(_ctx_id: u32, _c_path: *const c_char) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Accepts and ignores SMBIOS OEM strings; always returns 0.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_smbios_oem_strings(
+    _ctx_id: u32,
+    _strings: *const *const c_char,
+) -> i32 {
+    0 // No-op
+}
+
+/// Convenience wrapper around `krun_add_vsock_port2` with `listen` set to
+/// `false`.
+///
+/// # Safety
+///
+/// Same requirements as `krun_add_vsock_port2`, which receives `c_filepath`
+/// unchanged.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_vsock_port(
+    ctx_id: u32,
+    port: u32,
+    c_filepath: *const c_char,
+) -> i32 {
+    krun_add_vsock_port2(ctx_id, port, c_filepath, false)
+}
+
+/// Always returns `-libc::ENOSYS`; the path pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_set_tee_config_file(_ctx_id: u32, _c_filepath: *const c_char) -> i32 {
+    -libc::ENOSYS
+}
+
+// 
============================================================================
+// VM lifecycle
+// ============================================================================
+
+/// Start and enter the VM (blocking). Returns exit code.
+///
+/// The context is obtained through `context::take_ctx` and handed to the
+/// runner. A failure at either step is logged and converted to its `i32`
+/// error code.
+#[no_mangle]
+pub extern "C" fn krun_start_enter(ctx_id: u32) -> i32 {
+    let ctx = match context::take_ctx(ctx_id) {
+        Err(err) => {
+            log::error!("krun_start_enter: {}", err);
+            return i32::from(&err);
+        }
+        Ok(taken) => taken,
+    };
+
+    vmm::windows::runner::run(ctx).unwrap_or_else(|err| {
+        log::error!("krun_start_enter: {}", err);
+        i32::from(&err)
+    })
+}
+
+/// Start VM on a background thread (non-blocking). Returns 0 on success.
+///
+/// A `take_ctx` failure is logged and returned as its error code; the result
+/// of `runner::start` is translated by `to_c_result`.
+#[no_mangle]
+pub extern "C" fn krun_start(ctx_id: u32) -> i32 {
+    match context::take_ctx(ctx_id) {
+        Ok(ctx) => to_c_result(vmm::windows::runner::start(ctx_id, ctx)),
+        Err(err) => {
+            log::error!("krun_start: {}", err);
+            i32::from(&err)
+        }
+    }
+}
+
+/// Block until a running VM exits. Returns exit code.
+///
+/// A failure from `runner::wait` is logged and returned as its error code.
+#[no_mangle]
+pub extern "C" fn krun_wait(ctx_id: u32) -> i32 {
+    vmm::windows::runner::wait(ctx_id).unwrap_or_else(|err| {
+        log::error!("krun_wait: {}", err);
+        i32::from(&err)
+    })
+}
+
+/// Request a running VM to stop (non-blocking). Returns 0 on success.
+#[no_mangle]
+pub extern "C" fn krun_stop(ctx_id: u32) -> i32 {
+    let outcome = vmm::windows::runner::stop(ctx_id);
+    to_c_result(outcome)
+}
+
+/// Get captured console output for a VM.
+///
+/// If `buf` is null or `buf_size` is 0, returns the total number of bytes available.
+/// Otherwise, copies up to `buf_size` bytes into `buf` and returns the number copied.
+/// Returns -1 if the ctx_id has no console buffer.
+///
+/// Reported byte counts saturate at `i32::MAX`, and at most `i32::MAX` bytes
+/// are copied per call, so the return value is never negative on success and
+/// always equals the number of bytes actually written.
+///
+/// # Safety
+///
+/// When `buf` is non-null and `buf_size` is non-zero, `buf` must be valid for
+/// writes of `buf_size` bytes.
+#[no_mangle]
+pub unsafe extern "C" fn krun_get_console_output(ctx_id: u32, buf: *mut u8, buf_size: u32) -> i32 {
+    // Largest count the `i32` return value can carry. A plain `as i32` cast of
+    // anything larger wraps negative, which callers cannot tell apart from an
+    // error code.
+    const MAX_REPORTABLE: usize = i32::MAX as usize;
+
+    let output = match devices::get_console_output(ctx_id) {
+        Some(data) => data,
+        None => return -1,
+    };
+
+    // Size query: no destination buffer was supplied.
+    if buf.is_null() || buf_size == 0 {
+        return output.len().min(MAX_REPORTABLE) as i32;
+    }
+
+    // Never copy more than can be reported back to the caller.
+    let copy_len = output
+        .len()
+        .min(buf_size as usize)
+        .min(MAX_REPORTABLE);
+    if copy_len > 0 {
+        // SAFETY: `copy_len` is bounded by both the source length and the
+        // caller-declared capacity of `buf`.
+        std::ptr::copy_nonoverlapping(output.as_ptr(), buf, copy_len);
+    }
+    copy_len as i32
+}
+
+// ============================================================================
+// Display / Input / Console stubs (not supported on Windows)
+// ============================================================================
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub extern "C" fn krun_set_display_backend(_ctx_id: u32, _backend: u32) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_display(_ctx_id: u32, _width: u32, _height: u32) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub extern "C" fn krun_display_set_refresh_rate(_ctx_id: u32, _display_id: u32, _rate: u32) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub extern "C" fn krun_display_set_physical_size(
+    _ctx_id: u32,
+    _display_id: u32,
+    _mm_width: u32,
+    _mm_height: u32,
+) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub extern "C" fn krun_display_set_dpi(_ctx_id: u32, _display_id: u32, _dpi: u32) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`; the EDID pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_display_set_edid(
+    _ctx_id: u32,
+    _display_id: u32,
+    _edid: *const u8,
+    _edid_size: u32,
+) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`; the path pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_input_device(
+    _ctx_id: u32,
+    _c_path: *const c_char,
+    _input_type: u32,
+) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_input_device_fd(_ctx_id: u32, _input_fd: i32) -> i32 {
+    -libc::ENOSYS
+}
+
+/// Always returns `-libc::ENOSYS`; the name pointer is never read.
+#[no_mangle]
+pub unsafe extern "C" fn krun_add_virtio_console_default(
+    _ctx_id: u32,
+    _port_name: *const c_char,
+) -> i32 {
+    -libc::ENOSYS
+}
+
+#[no_mangle] +pub unsafe extern "C" fn krun_add_virtio_console_multiport(_ctx_id: u32) -> i32 { + -libc::ENOSYS +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_console_port_tty( + _ctx_id: u32, + _name: *const c_char, + _port_name: *const c_char, +) -> i32 { + -libc::ENOSYS +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_console_port_inout( + _ctx_id: u32, + _name: *const c_char, + _port_name: *const c_char, +) -> i32 { + -libc::ENOSYS +} + +#[no_mangle] +pub unsafe extern "C" fn krun_add_serial_console_default(_ctx_id: u32) -> i32 { + -libc::ENOSYS +} + +#[no_mangle] +pub unsafe extern "C" fn krun_set_kernel_console(_ctx_id: u32, _console_id: *const c_char) -> i32 { + -libc::ENOSYS +} + +// ============================================================================ +// Disk format 3 stub +// ============================================================================ + +#[no_mangle] +pub unsafe extern "C" fn krun_add_disk3( + ctx_id: u32, + c_block_id: *const c_char, + c_disk_path: *const c_char, + disk_format: u32, + read_only: bool, + _cache_type: u32, + _sync_mode: u32, +) -> i32 { + // Ignore cache_type and sync_mode on Windows, delegate to disk2. 
+ krun_add_disk2(ctx_id, c_block_id, c_disk_path, disk_format, read_only) +} + +// ============================================================================ +// GPU options 2 stub +// ============================================================================ + +#[no_mangle] +pub unsafe extern "C" fn krun_set_gpu_options2( + _ctx_id: u32, + _virgl_flags: u32, + _shm_size: u64, +) -> i32 { + 0 // No-op +} + +// ============================================================================ +// Nitro / TEE stubs +// ============================================================================ + +#[no_mangle] +pub unsafe extern "C" fn krun_nitro_set_image( + _ctx_id: u32, + _c_image_filepath: *const c_char, +) -> i32 { + -libc::ENOSYS +} + +#[no_mangle] +pub unsafe extern "C" fn krun_nitro_set_start_flags(_ctx_id: u32, _start_flags: u64) -> i32 { + -libc::ENOSYS +} + +// ============================================================================ +// Net tap stubs +// ============================================================================ + +#[no_mangle] +pub unsafe extern "C" fn krun_add_net_tap( + _ctx_id: u32, + _tap_name: *const c_char, + _c_mac: *const u8, +) -> i32 { + -libc::ENOSYS +} diff --git a/src/vm-memory/Cargo.lock b/src/vm-memory/Cargo.lock new file mode 100644 index 000000000..0916609fe --- /dev/null +++ b/src/vm-memory/Cargo.lock @@ -0,0 +1,689 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" + +[[package]] +name = "bumpalo" +version = "3.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "half" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "hermit-abi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" + +[[package]] +name = "is-terminal" +version = "0.4.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" +dependencies = [ + "hermit-abi", + "libc", + 
"windows-sys", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "js-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.172" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" + +[[package]] +name = "log" +version = "0.4.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" + +[[package]] +name = "matches" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "proc-macro2" +version = "1.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "rustversion" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.140" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +dependencies = [ + 
"itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "2.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "vm-memory" +version = "0.16.2" +dependencies = [ + "arc-swap", + "bitflags 2.9.1", + "criterion", + "libc", + "matches", + "thiserror", + "vmm-sys-util", + "winapi", +] + +[[package]] +name = "vmm-sys-util" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d21f366bf22bfba3e868349978766a965cbe628c323d58e026be80b8357ab789" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasm-bindgen" 
+version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = 
"winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = 
"windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml index 10e3d3674..59ba3ab68 100644 --- a/src/vmm/Cargo.toml +++ b/src/vmm/Cargo.toml @@ -18,13 +18,25 @@ gpu = ["devices/gpu", "krun_display"] snd = ["devices/snd"] input = ["devices/input", "krun_input"] aws-nitro = [] +efi = [] [dependencies] crossbeam-channel = ">=0.5.15" flate2 = "1.0.35" libc = ">=0.2.39" -linux-loader = { version = "0.13.2", features = ["bzimage", "elf", "pe"] } log = "0.4.0" + +# Dependencies for amd-sev +kbs-types = { version = "0.13.0", optional = true } +serde = { version = "1.0.125", optional = true } +serde_json = { version = "1.0.64", optional = true } +iocuddle = { version = "0.1.1", optional = true } +bitfield = { version = "0.19.4", optional = true } +bitflags = { version = "2.10.0", optional = true } + +# Unix-only dependencies (macOS + Linux) — the upstream Vmm infrastructure +[target.'cfg(unix)'.dependencies] +linux-loader = { version = "0.13.0", features = ["bzimage", 
"elf", "pe"] } nix = { version = "0.30.1", features = ["fs", "term"] } vm-memory = { version = "0.17.0", features = ["backend-mmap"] } vmm-sys-util = "0.14" @@ -38,19 +50,14 @@ kernel = { package = "krun-kernel", version = "=0.1.0-1.18.0", path = "../kernel utils = { package = "krun-utils", version = "=0.1.0-1.18.0", path = "../utils" } polly = { package = "krun-polly", version = "=0.1.0-1.18.0", path = "../polly" } -# Dependencies for amd-sev -kbs-types = { version = "0.13.0", optional = true } -serde = { version = "1.0.125", optional = true } -serde_json = { version = "1.0.64", optional = true } -iocuddle = { version = "0.1.1", optional = true } -bitfield = { version = "0.19.4", optional = true } -bitflags = { version = "2.10.0", optional = true } - [target.'cfg(target_arch = "x86_64")'.dependencies] bzip2 = "0.5" -cpuid = { package = "krun-cpuid", version = "=0.1.0-1.18.0", path = "../cpuid" } zstd = "0.13" +# cpuid is needed on Unix x86_64 only (upstream VMM uses it) +[target.'cfg(all(target_arch = "x86_64", unix))'.dependencies] +cpuid = { package = "krun-cpuid", version = "=0.1.0-1.18.0", path = "../cpuid" } + [target.'cfg(target_os = "linux")'.dependencies] tdx = { version = "0.1.0", optional = true } kvm-bindings = { version = "0.12", features = ["fam-wrappers"] } @@ -59,5 +66,23 @@ kvm-ioctls = "0.22" [target.'cfg(target_os = "macos")'.dependencies] hvf = { package = "krun-hvf", version = "=0.1.0-1.18.0", path = "../hvf" } -[dev-dependencies] +# Windows-only dependencies (WHPX backend) +[target.'cfg(target_os = "windows")'.dependencies] +thiserror = "2" +windows-sys = { version = "0.61", features = [ + "Win32_System_Hypervisor", + "Win32_System_Memory", + "Win32_System_LibraryLoader", +] } +zerocopy = { version = "0.8", features = ["derive"] } +rand = "0.9" +uds_windows = "1.2" + +[target.'cfg(target_os = "windows")'.dev-dependencies] +env_logger = "0.11" + +[target.'cfg(unix)'.dev-dependencies] devices = { package = "krun-devices", version = 
//! Smoke test: boot a Linux kernel inside a WHPX VM using the VMM runner.
//!
//! Usage:
//!
//! ```text
//! boot_kernel.exe <kernel> [initrd] [options] [-- extra-cmdline-args...]
//! ```
//!
//! Options:
//!
//! ```text
//! --disk <path>                             Attach a raw disk image as virtio-blk device
//! --init <path>                             Set init binary path (kernel `init=` parameter)
//! --root <device>                           Override root device (e.g., /dev/vda). Default: auto from --disk
//! --fstype <type>                           Root filesystem type (e.g., ext4). Used with --root
//! --argv <arg>                              Argument passed to init after `--` separator (repeat for each arg)
//! --vsock-listen <guest_port>:<host_port>   VMM listens on TCP, bridges to guest vsock
//! --vsock-connect <guest_port>:<host_port>  VMM connects to TCP when guest connects to vsock
//! --verbose                                 Enable serial console output (slower boot, useful for debugging)
//! ```
//!
//! Examples:
//!
//! ```text
//! # Boot with initramfs only (existing behavior)
//! boot_kernel.exe vmlinuz initrd.img
//!
//! # Boot with disk as root, kernel mounts /dev/vda automatically
//! boot_kernel.exe vmlinuz --disk rootfs.img
//!
//! # Boot with disk + explicit init binary
//! boot_kernel.exe vmlinuz --disk rootfs.img --init /bin/sh
//!
//! # Full lifecycle test: disk + init + argv
//! boot_kernel.exe vmlinuz --disk rootfs.img --init /init --argv --listen --argv vsock://2695
//! ```

#[cfg(not(target_os = "windows"))]
fn main() {
    eprintln!("boot_kernel: this example requires Windows (WHPX hypervisor)");
}

#[cfg(target_os = "windows")]
use std::path::PathBuf;
#[cfg(target_os = "windows")]
use vmm::windows::context::{DiskConfig, VsockPort, DISK_FORMAT_RAW};

/// Print `msg` to stderr and terminate the process with exit code 1.
#[cfg(target_os = "windows")]
fn die(msg: impl std::fmt::Display) -> ! {
    eprintln!("{}", msg);
    std::process::exit(1);
}

/// Consume the value that must follow `flag`, or exit with
/// "`<flag>` requires `<what>`" when the command line ends early.
#[cfg(target_os = "windows")]
fn required_value<'a>(
    rest: &mut std::slice::Iter<'a, String>,
    flag: &str,
    what: &str,
) -> &'a str {
    match rest.next() {
        Some(value) => value.as_str(),
        None => die(format!("{} requires {}", flag, what)),
    }
}

/// Parse a `<guest_port>:<host_port>` pair, exiting with a diagnostic when the
/// spec does not consist of exactly two valid port numbers.
#[cfg(target_os = "windows")]
fn parse_vsock_spec(spec: &str) -> (u32, u16) {
    let mut fields = spec.split(':');
    let (guest, host) = match (fields.next(), fields.next(), fields.next()) {
        (Some(guest), Some(host), None) => (guest, host),
        _ => die(format!(
            "Expected <guest_port>:<host_port>, got: {}",
            spec
        )),
    };
    let guest_port: u32 = guest
        .parse()
        .unwrap_or_else(|_| die(format!("Invalid guest port: {}", guest)));
    let host_port: u16 = host
        .parse()
        .unwrap_or_else(|_| die(format!("Invalid host port: {}", host)));
    (guest_port, host_port)
}

/// Render an optional path for the summary banner, or `(none)` when absent.
#[cfg(target_os = "windows")]
fn display_or_none(path: &Option<PathBuf>) -> String {
    path.as_ref()
        .map(|p| p.display().to_string())
        .unwrap_or_else(|| "(none)".to_string())
}

/// Entry point: parse the command line, configure a VM context and run the
/// guest synchronously, exiting with the code the runner reports.
#[cfg(target_os = "windows")]
fn main() {
    // Initialize logging (RUST_LOG controls verbosity).
    env_logger::init();

    let args: Vec<String> = std::env::args().collect();
    if args.len() < 2 {
        // argv may be empty on exotic launchers; do not index blindly.
        let program = args.first().map_or("boot_kernel", String::as_str);
        eprintln!(
            "Usage: {} <kernel> [initrd] [--disk <path>] [--init <path>] \
             [--root <device>] [--fstype <type>] [--argv <arg>]... \
             [--vsock-listen <guest_port>:<host_port>] \
             [--vsock-connect <guest_port>:<host_port>] [--verbose] \
             [-- extra-cmdline-args...]",
            program
        );
        std::process::exit(1);
    }

    let kernel_path = PathBuf::from(&args[1]);
    if !kernel_path.exists() {
        die(format!("Kernel not found: {}", kernel_path.display()));
    }

    // Parse optional arguments.
    let mut initrd_path: Option<PathBuf> = None;
    let mut disk_path: Option<PathBuf> = None;
    let mut init_path: Option<String> = None;
    let mut root_device: Option<String> = None;
    let mut root_fstype: Option<String> = None;
    let mut init_argv: Vec<String> = Vec::new();
    let mut vsock_ports: Vec<VsockPort> = Vec::new();
    let mut extra_cmdline: Vec<&str> = Vec::new();
    let mut verbose = false;

    let mut rest = args[2..].iter();
    while let Some(arg) = rest.next() {
        match arg.as_str() {
            "--" => {
                // Everything after the separator goes to the kernel cmdline
                // verbatim. Further `--` tokens are dropped, as before.
                extra_cmdline.extend(
                    rest.by_ref()
                        .map(String::as_str)
                        .filter(|token| *token != "--"),
                );
            }
            "--disk" => {
                let path = PathBuf::from(required_value(&mut rest, arg, "a path argument"));
                if !path.exists() {
                    die(format!("Disk image not found: {}", path.display()));
                }
                disk_path = Some(path);
            }
            "--init" => {
                init_path = Some(required_value(&mut rest, arg, "a path argument").to_owned());
            }
            "--root" => {
                root_device =
                    Some(required_value(&mut rest, arg, "a device argument").to_owned());
            }
            "--fstype" => {
                root_fstype = Some(required_value(&mut rest, arg, "a type argument").to_owned());
            }
            "--argv" => {
                init_argv.push(required_value(&mut rest, arg, "an argument").to_owned());
            }
            "--verbose" => {
                verbose = true;
            }
            "--vsock-listen" | "--vsock-connect" => {
                let spec = required_value(&mut rest, arg, "<guest_port>:<host_port>");
                let (guest_port, host_port) = parse_vsock_spec(spec);
                vsock_ports.push(VsockPort {
                    port: guest_port,
                    host_path: PathBuf::new(),
                    listen: arg.as_str() == "--vsock-listen",
                    host_tcp_port: Some(host_port),
                });
            }
            other => {
                // The first positional argument that names an existing file is
                // the initrd; anything else is forwarded to the kernel cmdline.
                if initrd_path.is_none() {
                    let candidate = PathBuf::from(other);
                    if candidate.exists() {
                        initrd_path = Some(candidate);
                    } else {
                        eprintln!(
                            "Warning: initrd not found: {}, treating as cmdline arg",
                            other
                        );
                        extra_cmdline.push(other);
                    }
                } else {
                    extra_cmdline.push(other);
                }
            }
        }
    }

    // Build the VmContext via the C-API-style context functions.
    let ctx_id = vmm::windows::context::create_ctx().expect("create_ctx failed");

    vmm::windows::context::with_ctx_mut(ctx_id, |ctx| {
        ctx.num_vcpus = 1;
        ctx.ram_mib = 256;
        ctx.kernel_path = Some(kernel_path.clone());
        ctx.initramfs_path = initrd_path.clone();

        // Attach disk if provided.
        if let Some(ref dp) = disk_path {
            ctx.disks.push(DiskConfig {
                block_id: "root".to_string(),
                path: dp.clone(),
                format: DISK_FORMAT_RAW,
                read_only: false,
            });
        }

        // Root disk device override.
        ctx.root_disk_device = root_device.clone();
        ctx.root_disk_fstype = root_fstype.clone();

        // Init binary path and arguments.
        ctx.exec_path = init_path.clone();
        ctx.argv = init_argv.clone();

        // Verbose mode: enable serial console output for debugging.
        ctx.verbose = verbose;

        // Extra cmdline args are appended after the base cmdline and MMIO
        // device lines that build_kernel_cmdline() generates automatically.
        if !extra_cmdline.is_empty() {
            ctx.kernel_cmdline = Some(extra_cmdline.join(" "));
        }

        Ok(())
    })
    .expect("configure ctx failed");

    println!("=== WHPX Smoke Test ===");
    println!("Kernel:  {}", kernel_path.display());
    println!("Initrd:  {}", display_or_none(&initrd_path));
    println!("Disk:    {}", display_or_none(&disk_path));
    if let Some(ref root) = root_device {
        println!(
            "Root:    {} (fstype: {})",
            root,
            root_fstype.as_deref().unwrap_or("auto")
        );
    }
    if let Some(ref init) = init_path {
        println!("Init:    {}", init);
    }
    if !init_argv.is_empty() {
        println!("Argv:    {:?}", init_argv);
    }
    if verbose {
        println!("Verbose: enabled (serial console on, slower boot)");
    }
    for vp in &vsock_ports {
        // Every entry built above carries an explicit TCP port; fall back to
        // the guest port as text instead of narrowing it to u16.
        let host_port = vp
            .host_tcp_port
            .map_or_else(|| vp.port.to_string(), |p| p.to_string());
        if vp.listen {
            println!(
                "Vsock:   guest:{} <- TCP listen:{} (host→guest)",
                vp.port, host_port
            );
        } else {
            println!(
                "Vsock:   guest:{} -> TCP connect:127.0.0.1:{} (guest→host)",
                vp.port, host_port
            );
        }
    }

    // Move vsock ports into context (after printing, since VsockPort doesn't impl Clone).
    if !vsock_ports.is_empty() {
        vmm::windows::context::with_ctx_mut(ctx_id, |ctx| {
            ctx.vsock_ports = vsock_ports;
            Ok(())
        })
        .expect("set vsock_ports failed");
    }

    // Take the context out of the global map and run synchronously.
    let ctx = vmm::windows::context::take_ctx(ctx_id).expect("take_ctx failed");

    println!("Starting VM...");
    match vmm::windows::runner::run(ctx) {
        Ok(code) => {
            println!("VM exited with code {}", code);
            std::process::exit(code);
        }
        Err(e) => {
            eprintln!("VM error: {}", e);
            std::process::exit(1);
        }
    }
}
+#[cfg(unix)] pub mod vmm_config; #[cfg(target_os = "linux")] @@ -30,39 +38,57 @@ mod linux; use crate::linux::vstate; #[cfg(target_os = "macos")] mod macos; +#[cfg(unix)] mod terminal; +#[cfg(unix)] pub mod worker; #[cfg(target_os = "macos")] use macos::vstate; +#[cfg(unix)] use std::fmt::{Display, Formatter}; +#[cfg(unix)] use std::io; +#[cfg(unix)] use std::os::unix::io::AsRawFd; +#[cfg(unix)] use std::sync::atomic::{AtomicI32, Ordering}; +#[cfg(unix)] use std::sync::{Arc, Mutex}; #[cfg(target_os = "linux")] use std::time::Duration; -#[cfg(target_arch = "x86_64")] +#[cfg(all(unix, target_arch = "x86_64"))] use crate::device_manager::legacy::PortIODeviceManager; +#[cfg(unix)] use crate::device_manager::mmio::MMIODeviceManager; #[cfg(target_os = "linux")] use crate::vstate::VcpuEvent; +#[cfg(unix)] use crate::vstate::{Vcpu, VcpuHandle, VcpuResponse, Vm}; +#[cfg(unix)] use arch::{ArchMemoryInfo, InitrdConfig}; #[cfg(target_os = "macos")] use crossbeam_channel::Sender; #[cfg(any(target_arch = "aarch64", target_arch = "riscv64"))] use devices::fdt; +#[cfg(unix)] use devices::legacy::IrqChip; +#[cfg(unix)] use devices::virtio::VmmExitObserver; +#[cfg(unix)] use devices::{BusDevice, DeviceType}; +#[cfg(unix)] use kernel::cmdline::Cmdline as KernelCmdline; +#[cfg(unix)] use polly::event_manager::{self, EventManager, Subscriber}; +#[cfg(unix)] use utils::epoll::{EpollEvent, EventSet}; +#[cfg(unix)] use utils::eventfd::EventFd; +#[cfg(unix)] use vm_memory::GuestMemoryMmap; /// Success exit code. @@ -85,6 +111,7 @@ pub const FC_EXIT_CODE_ARG_PARSING: u8 = 153; /// Errors associated with the VMM internal logic. These errors cannot be generated by direct user /// input, but can result from bad configuration of the host (for example if Firecracker doesn't /// have permissions to open the KVM fd). +#[cfg(unix)] #[derive(Debug)] pub enum Error { /// This error is thrown by the minimal boot loader implementation. 
//! Minimal ACPI table generation for WHPX guest boot.
//!
//! Generates RSDP, RSDT, FADT, DSDT, and MADT tables so the Linux kernel can:
//! - Discover the PM1a_CNT register for clean ACPI S5 shutdown
//! - Discover the IOAPIC and LAPIC for APIC-mode interrupt routing

/// Total size of the ACPI region in guest memory.
pub const ACPI_REGION_SIZE: u64 = 0x400; // 1024 bytes

// Table offsets within the ACPI region.
const RSDP_OFFSET: usize = 0x00;
const RSDT_OFFSET: usize = 0x20;
const FADT_OFFSET: usize = 0x60;
const DSDT_OFFSET: usize = 0x100;
const MADT_OFFSET: usize = 0x140;

// Table sizes.
const RSDP_SIZE: usize = 20;
const RSDT_HEADER_SIZE: usize = 36;
const RSDT_ENTRIES: usize = 2; // FADT + MADT
const RSDT_SIZE: usize = RSDT_HEADER_SIZE + RSDT_ENTRIES * 4; // 36 + 8 = 44
const FADT_SIZE: usize = 116;
const DSDT_HEADER_SIZE: usize = 36;

/// MADT structure sizes.
const MADT_HEADER_SIZE: usize = 44; // 36-byte ACPI header + 4-byte Local APIC Address + 4-byte Flags
const MADT_LAPIC_ENTRY_SIZE: usize = 8; // Type 0: Processor Local APIC
const MADT_IOAPIC_ENTRY_SIZE: usize = 12; // Type 1: I/O APIC
const MADT_ISO_ENTRY_SIZE: usize = 10; // Type 2: Interrupt Source Override

/// Compute the MADT size for a given number of vCPUs.
const fn madt_size(num_vcpus: u8) -> usize {
    MADT_HEADER_SIZE
        + MADT_LAPIC_ENTRY_SIZE * (num_vcpus as usize)
        + MADT_IOAPIC_ENTRY_SIZE
        + MADT_ISO_ENTRY_SIZE
}

/// MADT size for the default single-vCPU case (used for static offset validation).
const MADT_SIZE_1: usize = madt_size(1);

// Compile-time layout validation: each table must end before the next one
// starts, and the single-vCPU MADT must fit inside the region.
const _: () = {
    assert!(RSDP_OFFSET + RSDP_SIZE <= RSDT_OFFSET);
    assert!(RSDT_OFFSET + RSDT_SIZE <= FADT_OFFSET);
    assert!(FADT_OFFSET + FADT_SIZE <= DSDT_OFFSET);
    assert!(DSDT_OFFSET + DSDT_HEADER_SIZE + S5_AML.len() <= MADT_OFFSET);
    assert!(MADT_OFFSET + MADT_SIZE_1 <= ACPI_REGION_SIZE as usize);
};

// ACPI PM1a I/O port addresses (must match manager.rs constants).
const PM1A_EVT_BLK: u32 = 0x600;
const PM1A_CNT_BLK: u32 = 0x604;

/// SCI interrupt number for ACPI.
///
/// Must not conflict with timer (IRQ 0), serial (IRQ 4), or
/// virtio-MMIO devices (IRQ 5-9). IRQ 11 is unused.
const SCI_INT: u16 = 11;

/// IOAPIC base address (must match memory.rs).
const IOAPIC_BASE: u32 = 0xFEC0_0000;

/// LAPIC base address (must match memory.rs).
const LAPIC_BASE: u32 = 0xFEE0_0000;

// Identification strings stamped into every table header.
const OEM_ID: &[u8; 6] = b"BOXLTE";
const OEM_TABLE_ID: &[u8; 8] = b"BOXLITEV";
const CREATOR_ID: &[u8; 4] = b"BXLT";

// Byte index of the checksum field in the RSDP and in standard table headers.
const RSDP_CHECKSUM_OFFSET: usize = 8;
const SDT_CHECKSUM_OFFSET: usize = 9;

/// AML bytecode for the `\_S5_` sleep package.
///
/// Encodes: `Name(\_S5_, Package(4) { 5, 5, 0, 0 })`
///
/// The PkgLength byte counts itself plus every byte up to the end of the
/// package (NumElements + elements): 1 + 1 + 2 + 2 + 1 + 1 = 8. Declaring a
/// larger value would make the package claim bytes beyond the end of the DSDT.
const S5_AML: &[u8] = &[
    0x08, // NameOp
    0x5C, 0x5F, 0x53, 0x35, 0x5F, // `\_S5_`
    0x12, 0x08, 0x04, // PackageOp, PkgLength = 8, NumElements = 4
    0x0A, 0x05, // BytePrefix 5 (SLP_TYPa)
    0x0A, 0x05, // BytePrefix 5 (SLP_TYPb)
    0x00, // ZeroOp
    0x00, // ZeroOp
];

/// Build ACPI tables (RSDP, RSDT, FADT, DSDT, MADT) for the given base address.
///
/// Returns a `Vec<u8>` of exactly `ACPI_REGION_SIZE` bytes. The caller
/// writes this to guest memory at `acpi_base`.
///
/// `num_vcpus` determines how many LAPIC entries are generated in the MADT.
///
/// # Panics
///
/// Panics if the MADT for `num_vcpus` does not fit inside the region, or if
/// the region does not lie entirely below 4 GiB. The RSDP, RSDT and FADT all
/// store 32-bit physical pointers, so a higher base would be truncated into
/// pointers to unrelated memory.
pub fn build_acpi_tables(acpi_base: u64, num_vcpus: u8) -> Vec<u8> {
    let madt_sz = madt_size(num_vcpus);
    assert!(
        MADT_OFFSET + madt_sz <= ACPI_REGION_SIZE as usize,
        "MADT ({} bytes for {} vCPUs) exceeds ACPI region",
        madt_sz,
        num_vcpus,
    );

    let below_4g = acpi_base
        .checked_add(ACPI_REGION_SIZE)
        .map_or(false, |end| end <= 1u64 << 32);
    assert!(
        below_4g,
        "ACPI region at {:#x} must lie entirely below 4 GiB (32-bit table pointers)",
        acpi_base,
    );
    // Lossless after the range check above.
    let addr32 = |offset: usize| (acpi_base + offset as u64) as u32;

    let mut region = vec![0u8; ACPI_REGION_SIZE as usize];

    // ---- RSDP (20 bytes at offset 0x00) ----
    let rsdp = &mut region[RSDP_OFFSET..RSDP_OFFSET + RSDP_SIZE];
    rsdp[0..8].copy_from_slice(b"RSD PTR "); // Signature
    rsdp[9..15].copy_from_slice(OEM_ID); // OEMID
    rsdp[15] = 0; // Revision: ACPI 1.0
    rsdp[16..20].copy_from_slice(&addr32(RSDT_OFFSET).to_le_bytes()); // RsdtAddress
    acpi_checksum(rsdp, RSDP_CHECKSUM_OFFSET);

    // ---- RSDT (44 bytes at offset 0x20) ----
    let rsdt = &mut region[RSDT_OFFSET..RSDT_OFFSET + RSDT_SIZE];
    write_sdt_header(rsdt, b"RSDT");
    // Entry[0]: pointer to FADT, Entry[1]: pointer to MADT.
    rsdt[36..40].copy_from_slice(&addr32(FADT_OFFSET).to_le_bytes());
    rsdt[40..44].copy_from_slice(&addr32(MADT_OFFSET).to_le_bytes());
    acpi_checksum(rsdt, SDT_CHECKSUM_OFFSET);

    // ---- FADT (116 bytes at offset 0x60) ----
    let fadt = &mut region[FADT_OFFSET..FADT_OFFSET + FADT_SIZE];
    write_sdt_header(fadt, b"FACP"); // Signature is "FACP", NOT "FADT"
    // FACS pointer (offset 36) stays 0: not needed for shutdown.
    fadt[40..44].copy_from_slice(&addr32(DSDT_OFFSET).to_le_bytes()); // DSDT pointer
    fadt[46..48].copy_from_slice(&SCI_INT.to_le_bytes()); // SCI_INT
    fadt[56..60].copy_from_slice(&PM1A_EVT_BLK.to_le_bytes()); // PM1a_EVT_BLK
    fadt[64..68].copy_from_slice(&PM1A_CNT_BLK.to_le_bytes()); // PM1a_CNT_BLK
    fadt[88] = 4; // PM1_EVT_LEN
    fadt[89] = 2; // PM1_CNT_LEN
    acpi_checksum(fadt, SDT_CHECKSUM_OFFSET);

    // ---- DSDT (header + AML at offset 0x100) ----
    let dsdt_size = DSDT_HEADER_SIZE + S5_AML.len();
    let dsdt = &mut region[DSDT_OFFSET..DSDT_OFFSET + dsdt_size];
    write_sdt_header(dsdt, b"DSDT");
    dsdt[DSDT_HEADER_SIZE..].copy_from_slice(S5_AML); // AML body: \_S5_ package
    acpi_checksum(dsdt, SDT_CHECKSUM_OFFSET);

    // ---- MADT (Multiple APIC Description Table) at offset 0x140 ----
    build_madt(&mut region[MADT_OFFSET..MADT_OFFSET + madt_sz], num_vcpus);

    region
}

/// Fill in the common 36-byte ACPI table header for `table`.
///
/// The Length field is taken from `table.len()`, so the slice must span
/// exactly the table. The checksum byte (offset 9) is left at zero; callers
/// run [`acpi_checksum`] once the table body is complete.
fn write_sdt_header(table: &mut [u8], signature: &[u8; 4]) {
    let length = table.len() as u32;
    table[0..4].copy_from_slice(signature); // Signature
    table[4..8].copy_from_slice(&length.to_le_bytes()); // Length
    table[8] = 1; // Revision
    table[10..16].copy_from_slice(OEM_ID); // OEMID
    table[16..24].copy_from_slice(OEM_TABLE_ID); // OEM Table ID
    table[24..28].copy_from_slice(&1u32.to_le_bytes()); // OEM Revision
    table[28..32].copy_from_slice(CREATOR_ID); // Creator ID
    table[32..36].copy_from_slice(&1u32.to_le_bytes()); // Creator Revision
}

/// Build the MADT (Multiple APIC Description Table).
///
/// Tells the Linux kernel about the Local APIC(s) and I/O APIC.
///
/// Structure:
/// - Header (44 bytes): standard ACPI header + LAPIC address + flags
/// - N x Local APIC entries (type 0, 8 bytes each): one per vCPU
/// - I/O APIC entry (type 1, 12 bytes): IOAPIC ID 0, base 0xFEC00000
/// - Interrupt Source Override (type 2, 10 bytes): IRQ 0 → GSI 2
///
/// `madt` must be at least `madt_size(num_vcpus)` bytes long; its full length
/// is recorded in the header and covered by the checksum.
fn build_madt(madt: &mut [u8], num_vcpus: u8) {
    write_sdt_header(madt, b"APIC");

    // Local APIC Address (offset 36, 4 bytes).
    madt[36..40].copy_from_slice(&LAPIC_BASE.to_le_bytes());
    // Flags (offset 40, 4 bytes): PCAT_COMPAT = 1 (dual 8259 PICs present).
    madt[40..44].copy_from_slice(&1u32.to_le_bytes());

    // --- N x Processor Local APIC entries (type 0, 8 bytes each) ---
    let mut off = MADT_HEADER_SIZE;
    for id in 0..num_vcpus {
        let lapic = &mut madt[off..off + MADT_LAPIC_ENTRY_SIZE];
        lapic[0] = 0; // Entry type: Processor Local APIC
        lapic[1] = MADT_LAPIC_ENTRY_SIZE as u8; // Length
        lapic[2] = id; // ACPI Processor ID
        lapic[3] = id; // APIC ID
        lapic[4..8].copy_from_slice(&1u32.to_le_bytes()); // Flags: enabled
        off += MADT_LAPIC_ENTRY_SIZE;
    }

    // --- I/O APIC entry (type 1, 12 bytes) ---
    let ioapic = &mut madt[off..off + MADT_IOAPIC_ENTRY_SIZE];
    ioapic[0] = 1; // Entry type: I/O APIC
    ioapic[1] = MADT_IOAPIC_ENTRY_SIZE as u8; // Length
    ioapic[2] = 0; // I/O APIC ID
    ioapic[3] = 0; // Reserved
    ioapic[4..8].copy_from_slice(&IOAPIC_BASE.to_le_bytes()); // I/O APIC Address
    ioapic[8..12].copy_from_slice(&0u32.to_le_bytes()); // Global System Interrupt Base
    off += MADT_IOAPIC_ENTRY_SIZE;

    // --- Interrupt Source Override (type 2, 10 bytes) ---
    // Standard x86 convention: PIT timer (IRQ 0) routes to IOAPIC pin 2.
    let iso = &mut madt[off..off + MADT_ISO_ENTRY_SIZE];
    iso[0] = 2; // Entry type: Interrupt Source Override
    iso[1] = MADT_ISO_ENTRY_SIZE as u8; // Length
    iso[2] = 0; // Bus: ISA
    iso[3] = 0; // Source: IRQ 0 (PIT timer)
    iso[4..8].copy_from_slice(&2u32.to_le_bytes()); // Global System Interrupt: 2
    iso[8..10].copy_from_slice(&0u16.to_le_bytes()); // Flags: conforming

    acpi_checksum(madt, SDT_CHECKSUM_OFFSET);
}

/// Compute ACPI checksum and store it at `checksum_offset`.
///
/// The checksum byte is set so that the sum of all bytes in the table
/// equals zero (mod 256).
fn acpi_checksum(table: &mut [u8], checksum_offset: usize) {
    // Clear the field first so a stale value cannot influence the sum.
    table[checksum_offset] = 0;
    let sum = table.iter().fold(0u8, |acc, &b| acc.wrapping_add(b));
    table[checksum_offset] = sum.wrapping_neg();
}
+ let madt_ptr = u32::from_le_bytes(rsdt[40..44].try_into().unwrap()); + assert_eq!(madt_ptr, (TEST_BASE + MADT_OFFSET as u64) as u32); + } + + #[test] + fn test_fadt_signature_and_pm1a_cnt() { + let region = build_acpi_tables(TEST_BASE, 1); + let fadt = ®ion[FADT_OFFSET..FADT_OFFSET + FADT_SIZE]; + + assert_eq!(&fadt[0..4], b"FACP"); + + let pm1a_cnt = u32::from_le_bytes(fadt[64..68].try_into().unwrap()); + assert_eq!(pm1a_cnt, 0x604); + + let pm1a_evt = u32::from_le_bytes(fadt[56..60].try_into().unwrap()); + assert_eq!(pm1a_evt, 0x600); + + assert_eq!(fadt[88], 4, "PM1_EVT_LEN"); + assert_eq!(fadt[89], 2, "PM1_CNT_LEN"); + + let sum: u8 = fadt.iter().fold(0u8, |acc, &b| acc.wrapping_add(b)); + assert_eq!(sum, 0, "FADT checksum must be zero"); + } + + #[test] + fn test_dsdt_contains_s5_package() { + let region = build_acpi_tables(TEST_BASE, 1); + let dsdt_size = DSDT_HEADER_SIZE + S5_AML.len(); + let dsdt = ®ion[DSDT_OFFSET..DSDT_OFFSET + dsdt_size]; + + assert_eq!(&dsdt[0..4], b"DSDT"); + + // Verify \_S5_ AML is present. 
+ let aml = &dsdt[DSDT_HEADER_SIZE..]; + assert_eq!(aml, S5_AML); + + let sum: u8 = dsdt.iter().fold(0u8, |acc, &b| acc.wrapping_add(b)); + assert_eq!(sum, 0, "DSDT checksum must be zero"); + } + + #[test] + fn test_total_region_size() { + let region = build_acpi_tables(TEST_BASE, 1); + assert_eq!(region.len(), ACPI_REGION_SIZE as usize); + } + + #[test] + fn test_rsdp_points_to_rsdt() { + let region = build_acpi_tables(TEST_BASE, 1); + let rsdp = ®ion[RSDP_OFFSET..RSDP_OFFSET + RSDP_SIZE]; + let rsdt_addr = u32::from_le_bytes(rsdp[16..20].try_into().unwrap()); + assert_eq!(rsdt_addr, (TEST_BASE + RSDT_OFFSET as u64) as u32); + } + + #[test] + fn test_fadt_sci_int() { + let region = build_acpi_tables(TEST_BASE, 1); + let fadt = ®ion[FADT_OFFSET..FADT_OFFSET + FADT_SIZE]; + let sci_int = u16::from_le_bytes(fadt[46..48].try_into().unwrap()); + assert_eq!(sci_int, 11, "SCI_INT must be on an unused IRQ"); + } + + #[test] + fn test_fadt_points_to_dsdt() { + let region = build_acpi_tables(TEST_BASE, 1); + let fadt = ®ion[FADT_OFFSET..FADT_OFFSET + FADT_SIZE]; + let dsdt_addr = u32::from_le_bytes(fadt[40..44].try_into().unwrap()); + assert_eq!(dsdt_addr, (TEST_BASE + DSDT_OFFSET as u64) as u32); + } + + // ---- MADT tests ---- + + #[test] + fn test_madt_signature_and_checksum() { + let region = build_acpi_tables(TEST_BASE, 1); + let madt = ®ion[MADT_OFFSET..MADT_OFFSET + MADT_SIZE_1]; + + assert_eq!(&madt[0..4], b"APIC"); + + let length = u32::from_le_bytes(madt[4..8].try_into().unwrap()); + assert_eq!(length, MADT_SIZE_1 as u32); + + let sum: u8 = madt.iter().fold(0u8, |acc, &b| acc.wrapping_add(b)); + assert_eq!(sum, 0, "MADT checksum must be zero"); + } + + #[test] + fn test_madt_lapic_address() { + let region = build_acpi_tables(TEST_BASE, 1); + let madt = ®ion[MADT_OFFSET..MADT_OFFSET + MADT_SIZE_1]; + + let lapic_addr = u32::from_le_bytes(madt[36..40].try_into().unwrap()); + assert_eq!(lapic_addr, LAPIC_BASE); + } + + #[test] + fn test_madt_pcat_compat_flag() { + 
let region = build_acpi_tables(TEST_BASE, 1); + let madt = ®ion[MADT_OFFSET..MADT_OFFSET + MADT_SIZE_1]; + + let flags = u32::from_le_bytes(madt[40..44].try_into().unwrap()); + assert_eq!(flags, 1, "PCAT_COMPAT flag must be set"); + } + + #[test] + fn test_madt_lapic_entry() { + let region = build_acpi_tables(TEST_BASE, 1); + let off = MADT_OFFSET + MADT_HEADER_SIZE; + + assert_eq!(region[off], 0, "entry type: Local APIC"); + assert_eq!(region[off + 1], 8, "entry length"); + assert_eq!(region[off + 2], 0, "ACPI Processor ID"); + assert_eq!(region[off + 3], 0, "APIC ID"); + let flags = u32::from_le_bytes(region[off + 4..off + 8].try_into().unwrap()); + assert_eq!(flags, 1, "enabled flag"); + } + + #[test] + fn test_madt_ioapic_entry() { + let region = build_acpi_tables(TEST_BASE, 1); + let off = MADT_OFFSET + MADT_HEADER_SIZE + MADT_LAPIC_ENTRY_SIZE; + + assert_eq!(region[off], 1, "entry type: I/O APIC"); + assert_eq!(region[off + 1], 12, "entry length"); + assert_eq!(region[off + 2], 0, "I/O APIC ID"); + let ioapic_addr = u32::from_le_bytes(region[off + 4..off + 8].try_into().unwrap()); + assert_eq!(ioapic_addr, IOAPIC_BASE); + let gsi_base = u32::from_le_bytes(region[off + 8..off + 12].try_into().unwrap()); + assert_eq!(gsi_base, 0, "GSI base must be 0"); + } + + #[test] + fn test_madt_interrupt_source_override() { + let region = build_acpi_tables(TEST_BASE, 1); + let off = MADT_OFFSET + MADT_HEADER_SIZE + MADT_LAPIC_ENTRY_SIZE + MADT_IOAPIC_ENTRY_SIZE; + + assert_eq!(region[off], 2, "entry type: Interrupt Source Override"); + assert_eq!(region[off + 1], 10, "entry length"); + assert_eq!(region[off + 2], 0, "bus: ISA"); + assert_eq!(region[off + 3], 0, "source: IRQ 0"); + let gsi = u32::from_le_bytes(region[off + 4..off + 8].try_into().unwrap()); + assert_eq!(gsi, 2, "GSI: IRQ 0 → pin 2"); + let flags = u16::from_le_bytes(region[off + 8..off + 10].try_into().unwrap()); + assert_eq!(flags, 0, "conforming polarity/trigger"); + } + + #[test] + fn 
test_tables_do_not_overlap() { + // Verify no ACPI tables overlap each other. + let dsdt_size = DSDT_HEADER_SIZE + S5_AML.len(); + let tables = [ + ("RSDP", RSDP_OFFSET, RSDP_OFFSET + RSDP_SIZE), + ("RSDT", RSDT_OFFSET, RSDT_OFFSET + RSDT_SIZE), + ("FADT", FADT_OFFSET, FADT_OFFSET + FADT_SIZE), + ("DSDT", DSDT_OFFSET, DSDT_OFFSET + dsdt_size), + ("MADT", MADT_OFFSET, MADT_OFFSET + MADT_SIZE_1), + ]; + + for i in 0..tables.len() { + for j in (i + 1)..tables.len() { + let (name_a, start_a, end_a) = tables[i]; + let (name_b, start_b, end_b) = tables[j]; + assert!( + end_a <= start_b || end_b <= start_a, + "{} [{:#X}..{:#X}) overlaps {} [{:#X}..{:#X})", + name_a, + start_a, + end_a, + name_b, + start_b, + end_b + ); + } + } + } + + #[test] + fn test_all_tables_fit_in_region() { + let dsdt_size = DSDT_HEADER_SIZE + S5_AML.len(); + let last_table_end = MADT_OFFSET + MADT_SIZE_1; + assert!( + last_table_end <= ACPI_REGION_SIZE as usize, + "tables extend beyond region: {} > {}", + last_table_end, + ACPI_REGION_SIZE + ); + // Also verify DSDT doesn't extend into MADT. + assert!(DSDT_OFFSET + dsdt_size <= MADT_OFFSET); + } + + // ---- Multi-vCPU MADT tests ---- + + #[test] + fn test_madt_multi_vcpu_lapic_entries() { + let region = build_acpi_tables(TEST_BASE, 4); + let madt_sz = madt_size(4); + let madt = ®ion[MADT_OFFSET..MADT_OFFSET + madt_sz]; + + // Verify MADT length field matches. + let length = u32::from_le_bytes(madt[4..8].try_into().unwrap()); + assert_eq!(length, madt_sz as u32); + + // Verify checksum. + let sum: u8 = madt.iter().fold(0u8, |acc, &b| acc.wrapping_add(b)); + assert_eq!(sum, 0, "MADT checksum must be zero for 4 vCPUs"); + + // Verify 4 LAPIC entries with correct IDs. 
+ for i in 0..4u8 { + let off = MADT_HEADER_SIZE + (i as usize) * MADT_LAPIC_ENTRY_SIZE; + assert_eq!(madt[off], 0, "entry type: Local APIC for vCPU {}", i); + assert_eq!(madt[off + 1], 8, "entry length for vCPU {}", i); + assert_eq!(madt[off + 2], i, "ACPI Processor ID for vCPU {}", i); + assert_eq!(madt[off + 3], i, "APIC ID for vCPU {}", i); + let flags = u32::from_le_bytes(madt[off + 4..off + 8].try_into().unwrap()); + assert_eq!(flags, 1, "enabled flag for vCPU {}", i); + } + + // Verify IOAPIC entry follows the 4 LAPIC entries. + let ioapic_off = MADT_HEADER_SIZE + 4 * MADT_LAPIC_ENTRY_SIZE; + assert_eq!(madt[ioapic_off], 1, "entry type: I/O APIC"); + } + + #[test] + fn test_madt_size_scales_with_vcpus() { + assert_eq!(madt_size(1), MADT_SIZE_1); + assert_eq!( + madt_size(2), + MADT_SIZE_1 + MADT_LAPIC_ENTRY_SIZE, + "2 vCPUs adds one more LAPIC entry" + ); + assert_eq!( + madt_size(4), + MADT_SIZE_1 + 3 * MADT_LAPIC_ENTRY_SIZE, + "4 vCPUs adds three more LAPIC entries" + ); + } +} diff --git a/src/vmm/src/windows/boot/loader.rs b/src/vmm/src/windows/boot/loader.rs new file mode 100644 index 000000000..7145ff5ae --- /dev/null +++ b/src/vmm/src/windows/boot/loader.rs @@ -0,0 +1,632 @@ +//! Linux bzImage kernel loader. +//! +//! Parses a bzImage file, loads the protected-mode kernel into guest memory, +//! sets up page tables, GDT, boot parameters, and kernel command line. + +use super::super::error::{Result, WkrunError}; +use super::params::HDRS_MAGIC; + +#[cfg(any(target_os = "windows", test))] +use super::super::memory::{ + IOAPIC_MMIO_BASE, LAPIC_MMIO_BASE, LAPIC_MMIO_SIZE, MMIO_REGION_SIZE, VIRTIO_MMIO_BASE, +}; +#[cfg(any(target_os = "windows", test))] +use super::params::{E820Entry, E820_ACPI, E820_RAM, E820_RESERVED}; + +// These imports are only used by the Windows-only load_kernel() function. 
+#[cfg(target_os = "windows")] +use super::super::memory::{ + ACPI_START, CMDLINE_MAX_SIZE, CMDLINE_START, KERNEL_64BIT_ENTRY_OFFSET, KERNEL_START, + PDPT_START, PD_START, PML4_START, ZERO_PAGE_START, +}; +#[cfg(target_os = "windows")] +use super::super::types::{SpecialRegisters, StandardRegisters}; +#[cfg(target_os = "windows")] +use super::acpi; +#[cfg(target_os = "windows")] +use super::params::BootParams; +#[cfg(target_os = "windows")] +use super::setup::{build_gdt, build_page_tables, configure_boot_registers, gdt_bytes}; + +/// Loadflags bit: kernel was loaded high (at 0x100000). +#[cfg(any(target_os = "windows", test))] +const LOADED_HIGH: u8 = 0x01; + +/// Loadflags bit: can use heap (setup heap). +#[cfg(target_os = "windows")] +const CAN_USE_HEAP: u8 = 0x80; + +/// Parsed bzImage header information. +#[derive(Debug)] +pub struct KernelHeader { + /// Boot protocol version (e.g., 0x020F for 2.15). + pub protocol_version: u16, + /// Number of setup sectors (real-mode kernel). + pub setup_sects: u8, + /// Byte offset of the protected-mode kernel within the bzImage. + pub kernel_offset: usize, + /// Size of the protected-mode kernel in bytes. + pub kernel_size: usize, + /// Load flags from the setup header. + pub loadflags: u8, +} + +/// Parse a bzImage and extract header information. +/// +/// Validates the setup header magic ("HdrS") and protocol version, +/// then computes the offset and size of the protected-mode kernel. +pub fn parse_bzimage(kernel_image: &[u8]) -> Result { + // Minimum size: at least the setup header through version field (0x208). + if kernel_image.len() < 0x208 { + return Err(WkrunError::Boot(format!( + "kernel image too small: {} bytes (need at least {})", + kernel_image.len(), + 0x208 + ))); + } + + // Check "HdrS" magic at offset 0x202. 
+ let header_magic = u32::from_le_bytes( + kernel_image[0x202..0x206] + .try_into() + .map_err(|_| WkrunError::Boot("failed to read header magic".into()))?, + ); + if header_magic != HDRS_MAGIC { + return Err(WkrunError::Boot(format!( + "invalid bzImage header magic: expected 0x{:08X} (HdrS), got 0x{:08X}", + HDRS_MAGIC, header_magic + ))); + } + + // Read boot protocol version at offset 0x206. + let protocol_version = u16::from_le_bytes( + kernel_image[0x206..0x208] + .try_into() + .map_err(|_| WkrunError::Boot("failed to read protocol version".into()))?, + ); + + // We require protocol version >= 2.06 for 64-bit boot. + if protocol_version < 0x0206 { + return Err(WkrunError::Boot(format!( + "boot protocol version 0x{:04X} too old (need >= 0x0206)", + protocol_version + ))); + } + + // Read setup_sects at offset 0x1F1. If 0, default to 4. + let mut setup_sects = kernel_image[0x1F1]; + if setup_sects == 0 { + setup_sects = 4; + } + + // Read loadflags at offset 0x211. + let loadflags = kernel_image[0x211]; + + // Protected-mode kernel starts after (setup_sects + 1) * 512 bytes. + // The "+1" accounts for the boot sector (first 512 bytes). + let kernel_offset = (setup_sects as usize + 1) * 512; + if kernel_offset >= kernel_image.len() { + return Err(WkrunError::Boot(format!( + "setup_sects {} puts kernel offset {} beyond image size {}", + setup_sects, + kernel_offset, + kernel_image.len() + ))); + } + + let kernel_size = kernel_image.len() - kernel_offset; + + Ok(KernelHeader { + protocol_version, + setup_sects, + kernel_offset, + kernel_size, + loadflags, + }) +} + +/// Build the E820 memory map for the guest. +/// +/// Creates a standard memory map with: +/// - Low memory (0 .. 0x9FC00) — 640KB conventional +/// - Reserved (0x9FC00 .. 0x100000) — BIOS area +/// - ACPI tables (acpi_base .. acpi_base + acpi_size) +/// - High memory (0x100000 .. 
ram_end) — main RAM +#[cfg(any(target_os = "windows", test))] +fn build_e820_map(ram_mib: u32, acpi_base: u64, acpi_size: u64) -> Vec { + let ram_bytes = (ram_mib as u64) * 1024 * 1024; + + let mut entries = Vec::new(); + + // Low memory: 0 to 640KB (conventional memory). + entries.push(E820Entry { + addr: 0, + size: 0x9FC00, + entry_type: E820_RAM, + _pad: 0, + }); + + // Reserved: 640KB to 1MB (BIOS, VGA, etc). + entries.push(E820Entry { + addr: 0x9FC00, + size: 0x100000 - 0x9FC00, + entry_type: E820_RESERVED, + _pad: 0, + }); + + // ACPI tables (within the BIOS reserved region). + entries.push(E820Entry { + addr: acpi_base, + size: acpi_size, + entry_type: E820_ACPI, + _pad: 0, + }); + + // High memory: 1MB to end of RAM. + // When RAM extends past device MMIO regions, split around reserved holes + // so the kernel doesn't try to use MMIO addresses as regular RAM. + // + // Two MMIO holes may exist: + // 1. Virtio MMIO: 0xD000_0000 .. 0xD020_0000 + // 2. APIC MMIO: 0xFEC0_0000 .. 0xFEE0_1000 (IOAPIC + LAPIC) + if ram_bytes > 0x100000 { + let apic_start = IOAPIC_MMIO_BASE; + let apic_end = LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE; + + if ram_bytes > VIRTIO_MMIO_BASE { + let mmio_end = VIRTIO_MMIO_BASE + MMIO_REGION_SIZE; + + // High memory below Virtio MMIO. + entries.push(E820Entry { + addr: 0x100000, + size: VIRTIO_MMIO_BASE - 0x100000, + entry_type: E820_RAM, + _pad: 0, + }); + + // Virtio MMIO region (reserved). + entries.push(E820Entry { + addr: VIRTIO_MMIO_BASE, + size: MMIO_REGION_SIZE, + entry_type: E820_RESERVED, + _pad: 0, + }); + + if ram_bytes > apic_start { + // RAM extends past APIC region — add APIC hole. + // RAM between Virtio MMIO end and APIC start. + entries.push(E820Entry { + addr: mmio_end, + size: apic_start - mmio_end, + entry_type: E820_RAM, + _pad: 0, + }); + + // APIC MMIO region (reserved: IOAPIC + LAPIC). 
+ entries.push(E820Entry { + addr: apic_start, + size: apic_end - apic_start, + entry_type: E820_RESERVED, + _pad: 0, + }); + + // RAM above APIC region. + if ram_bytes > apic_end { + entries.push(E820Entry { + addr: apic_end, + size: ram_bytes - apic_end, + entry_type: E820_RAM, + _pad: 0, + }); + } + } else if ram_bytes > mmio_end { + // RAM between Virtio MMIO and APIC — no APIC hole needed. + entries.push(E820Entry { + addr: mmio_end, + size: ram_bytes - mmio_end, + entry_type: E820_RAM, + _pad: 0, + }); + } + } else { + entries.push(E820Entry { + addr: 0x100000, + size: ram_bytes - 0x100000, + entry_type: E820_RAM, + _pad: 0, + }); + } + } + + entries +} + +/// Load a Linux bzImage kernel into guest memory and configure for boot. +/// +/// This performs the complete boot setup: +/// 1. Parse the bzImage header +/// 2. Copy the protected-mode kernel to KERNEL_START (0x100000) +/// 3. Write page tables (PML4, PDPT, PD) to guest memory +/// 4. Write GDT to guest memory +/// 5. Write boot parameters (zero page) with E820 map +/// 6. Write kernel command line +/// 7. Optionally load initrd into high guest memory +/// 8. Configure vCPU registers for 64-bit long mode entry +/// +/// Returns the initial vCPU register state. +#[cfg(target_os = "windows")] +pub fn load_kernel( + guest_mem: &super::super::memory::GuestMemory, + kernel_image: &[u8], + cmdline: &str, + ram_mib: u32, + num_vcpus: u8, +) -> Result<(StandardRegisters, SpecialRegisters)> { + load_kernel_with_initrd(guest_mem, kernel_image, cmdline, ram_mib, None, num_vcpus) +} + +/// Load a Linux bzImage kernel with an optional initrd. +#[cfg(target_os = "windows")] +pub fn load_kernel_with_initrd( + guest_mem: &super::super::memory::GuestMemory, + kernel_image: &[u8], + cmdline: &str, + ram_mib: u32, + initrd: Option<&[u8]>, + num_vcpus: u8, +) -> Result<(StandardRegisters, SpecialRegisters)> { + let header = parse_bzimage(kernel_image)?; + + // Validate kernel fits in guest memory. 
+ let kernel_end = KERNEL_START + header.kernel_size as u64; + let ram_bytes = (ram_mib as u64) * 1024 * 1024; + if kernel_end > ram_bytes { + return Err(WkrunError::Boot(format!( + "kernel ({} bytes) doesn't fit in {} MiB RAM (needs at least 0x{:X} bytes)", + header.kernel_size, ram_mib, kernel_end + ))); + } + + // Validate command line fits. + let cmdline_bytes = cmdline.as_bytes(); + if cmdline_bytes.len() as u64 + 1 > CMDLINE_MAX_SIZE { + return Err(WkrunError::Boot(format!( + "kernel command line too long: {} bytes (max {})", + cmdline_bytes.len(), + CMDLINE_MAX_SIZE - 1 + ))); + } + + // 1. Copy protected-mode kernel to KERNEL_START. + let kernel_data = &kernel_image[header.kernel_offset..]; + guest_mem.write_at_addr(KERNEL_START, kernel_data)?; + + // 2. Write page tables. + let page_tables = build_page_tables(); + guest_mem.write_at_addr(PML4_START, page_tables.pml4_bytes())?; + guest_mem.write_at_addr(PDPT_START, page_tables.pdpt_bytes())?; + for i in 0..4 { + guest_mem.write_at_addr(PD_START + i as u64 * 0x1000, page_tables.pd_bytes(i))?; + } + + // 3. Write GDT. + let gdt = build_gdt(); + let gdt_data = gdt_bytes(&gdt); + // GDT_ADDR is 0x500, defined in setup.rs. Use the constant from memory layout. + guest_mem.write_at_addr(0x500, &gdt_data)?; + + // 4. Build and write boot parameters (zero page). + let mut boot_params = BootParams::new(); + boot_params.set_boot_flag(); + boot_params.set_header_magic(); + boot_params.set_version(header.protocol_version); + boot_params.set_loader_type(0xFF); // Undefined bootloader + boot_params.set_loadflags(LOADED_HIGH | CAN_USE_HEAP); + + // Copy relevant fields from the kernel's own setup header into boot_params. + // The kernel reads some fields back from the zero page that it originally set. + copy_setup_header(&mut boot_params, kernel_image, &header); + + // Set kernel command line. 
+ boot_params.set_cmdline_ptr(CMDLINE_START as u32); + boot_params.set_cmdline_size(cmdline_bytes.len() as u32); + + // Write ACPI tables to guest memory and tell the kernel where to find them. + let acpi_data = acpi::build_acpi_tables(ACPI_START, num_vcpus); + guest_mem.write_at_addr(ACPI_START, &acpi_data)?; + boot_params.set_acpi_rsdp_addr(ACPI_START); + + // Write Intel MP table for SMP CPU discovery. + // The kernel (CONFIG_X86_MPPARSE=y) scans 0x9FC00-0x9FFFF for the MP + // Floating Pointer Structure. This is needed because CONFIG_ACPI is not + // enabled in the kernel, so the MADT-based CPU discovery doesn't work. + let mp_data = super::mp_table::build_mp_tables(num_vcpus); + guest_mem.write_at_addr(super::mp_table::MP_FPS_ADDR, &mp_data)?; + + // Set E820 memory map. + let e820_map = build_e820_map(ram_mib, ACPI_START, acpi::ACPI_REGION_SIZE); + boot_params.set_e820_map(&e820_map); + + // Load initrd if provided. Place at the end of RAM (page-aligned). + if let Some(initrd_data) = initrd { + if !initrd_data.is_empty() { + let initrd_size = initrd_data.len() as u64; + // Align initrd to end of RAM, start at page boundary. + let initrd_end = ram_bytes; + let initrd_start = (initrd_end - initrd_size) & !0xFFF; // Page-align down + + if initrd_start < kernel_end { + return Err(WkrunError::Boot(format!( + "initrd ({} bytes) overlaps with kernel at 0x{:X} (initrd at 0x{:X})", + initrd_size, kernel_end, initrd_start + ))); + } + + guest_mem.write_at_addr(initrd_start, initrd_data)?; + boot_params.set_ramdisk(initrd_start as u32, initrd_data.len() as u32); + } + } + + guest_mem.write_at_addr(ZERO_PAGE_START, &boot_params.data)?; + + // 5. Write kernel command line (null-terminated). + let mut cmdline_buf = cmdline_bytes.to_vec(); + cmdline_buf.push(0); // null terminator + guest_mem.write_at_addr(CMDLINE_START, &cmdline_buf)?; + + // 6. Configure vCPU registers for 64-bit long mode. + // The 64-bit entry point (startup_64) is at KERNEL_START + 0x200. 
+ Ok(configure_boot_registers( + KERNEL_START + KERNEL_64BIT_ENTRY_OFFSET, + )) +} + +/// Copy select fields from the kernel's setup header into boot_params. +/// +/// The kernel expects certain fields in the zero page to match what it +/// originally placed in its own setup header. We copy the fields that +/// the kernel reads back during early boot. +#[cfg(target_os = "windows")] +fn copy_setup_header(boot_params: &mut BootParams, kernel_image: &[u8], header: &KernelHeader) { + // setup_sects at offset 0x1F1. + boot_params.data[0x1F1] = header.setup_sects; + + // Copy the setup header region (0x1F1..0x268) from the kernel image. + // This includes fields like code32_start, kernel_alignment, init_size, etc. + // that the kernel reads back during boot. + let header_end = 0x268.min(kernel_image.len()); + if header_end > 0x1F1 { + let src = &kernel_image[0x1F1..header_end]; + boot_params.data[0x1F1..header_end].copy_from_slice(src); + } + + // Override the fields we explicitly set (they take precedence over what + // was in the original kernel header). + boot_params.set_boot_flag(); + boot_params.set_header_magic(); + boot_params.set_version(header.protocol_version); + boot_params.set_loader_type(0xFF); + boot_params.set_loadflags(LOADED_HIGH | CAN_USE_HEAP); +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a minimal valid bzImage header for testing. 
+ fn make_test_bzimage(setup_sects: u8, protocol_version: u16, kernel_payload: &[u8]) -> Vec { + // Total real-mode size: (setup_sects + 1) * 512 + let real_mode_size = (setup_sects as usize + 1) * 512; + let mut image = vec![0u8; real_mode_size + kernel_payload.len()]; + + // setup_sects at 0x1F1 + image[0x1F1] = setup_sects; + + // "HdrS" magic at 0x202 + image[0x202..0x206].copy_from_slice(&HDRS_MAGIC.to_le_bytes()); + + // Protocol version at 0x206 + image[0x206..0x208].copy_from_slice(&protocol_version.to_le_bytes()); + + // Loadflags at 0x211 (LOADED_HIGH) + image[0x211] = LOADED_HIGH; + + // Copy kernel payload after real-mode code + image[real_mode_size..].copy_from_slice(kernel_payload); + + image + } + + #[test] + fn test_parse_bzimage_valid() { + let kernel_payload = vec![0xCC; 1024]; // 1KB of int3 + let image = make_test_bzimage(4, 0x020F, &kernel_payload); + + let header = parse_bzimage(&image).expect("should parse valid bzImage"); + assert_eq!(header.protocol_version, 0x020F); + assert_eq!(header.setup_sects, 4); + assert_eq!(header.kernel_offset, (4 + 1) * 512); + assert_eq!(header.kernel_size, 1024); + assert_eq!(header.loadflags & LOADED_HIGH, LOADED_HIGH); + } + + #[test] + fn test_parse_bzimage_setup_sects_zero_defaults_to_4() { + // setup_sects=0 defaults to 4, so kernel_offset = (4+1)*512 = 2560. + // Build image large enough to accommodate this. 
+ let mut image = vec![0u8; (4 + 1) * 512 + 512]; // real-mode + kernel + image[0x1F1] = 0; // setup_sects = 0 + image[0x202..0x206].copy_from_slice(&HDRS_MAGIC.to_le_bytes()); + image[0x206..0x208].copy_from_slice(&0x0206u16.to_le_bytes()); + image[0x211] = LOADED_HIGH; + + let header = parse_bzimage(&image).expect("should parse with setup_sects=0"); + assert_eq!(header.setup_sects, 4); // defaulted from 0 + assert_eq!(header.kernel_offset, (4 + 1) * 512); + } + + #[test] + fn test_parse_bzimage_too_small() { + let image = vec![0u8; 100]; // Way too small + let err = parse_bzimage(&image).unwrap_err(); + assert!( + err.to_string().contains("too small"), + "unexpected error: {}", + err + ); + } + + #[test] + fn test_parse_bzimage_bad_magic() { + let mut image = vec![0u8; 0x300]; + image[0x1F1] = 1; + // Don't set "HdrS" magic + let err = parse_bzimage(&image).unwrap_err(); + assert!( + err.to_string().contains("header magic"), + "unexpected error: {}", + err + ); + } + + #[test] + fn test_parse_bzimage_old_protocol() { + let mut image = vec![0u8; 0x300]; + image[0x1F1] = 1; + image[0x202..0x206].copy_from_slice(&HDRS_MAGIC.to_le_bytes()); + image[0x206..0x208].copy_from_slice(&0x0200u16.to_le_bytes()); // too old + let err = parse_bzimage(&image).unwrap_err(); + assert!( + err.to_string().contains("too old"), + "unexpected error: {}", + err + ); + } + + #[test] + fn test_parse_bzimage_kernel_offset_beyond_image() { + let mut image = vec![0u8; 0x300]; // only ~768 bytes + image[0x1F1] = 10; // setup_sects=10 → offset = 11*512 = 5632 > 768 + image[0x202..0x206].copy_from_slice(&HDRS_MAGIC.to_le_bytes()); + image[0x206..0x208].copy_from_slice(&0x0206u16.to_le_bytes()); + let err = parse_bzimage(&image).unwrap_err(); + assert!( + err.to_string().contains("beyond image size"), + "unexpected error: {}", + err + ); + } + + // ACPI constants for test assertions. 
+ const TEST_ACPI_BASE: u64 = 0xE0000; + const TEST_ACPI_SIZE: u64 = 0x200; + + #[test] + fn test_build_e820_map_256mb() { + let map = build_e820_map(256, TEST_ACPI_BASE, TEST_ACPI_SIZE); + assert_eq!(map.len(), 4); + + // Low memory: 0 .. 640KB + assert_eq!(map[0].addr, 0); + assert_eq!(map[0].size, 0x9FC00); + assert_eq!(map[0].entry_type, E820_RAM); + + // Reserved: 640KB .. 1MB + assert_eq!(map[1].addr, 0x9FC00); + assert_eq!(map[1].entry_type, E820_RESERVED); + + // ACPI tables + assert_eq!(map[2].addr, TEST_ACPI_BASE); + assert_eq!(map[2].size, TEST_ACPI_SIZE); + assert_eq!(map[2].entry_type, E820_ACPI); + + // High memory: 1MB .. 256MB + assert_eq!(map[3].addr, 0x100000); + assert_eq!(map[3].size, 256 * 1024 * 1024 - 0x100000); + assert_eq!(map[3].entry_type, E820_RAM); + } + + #[test] + fn test_build_e820_map_1mb_no_high_memory() { + // With only 1MB of RAM, high memory region should be empty (1MB - 1MB = 0). + let map = build_e820_map(1, TEST_ACPI_BASE, TEST_ACPI_SIZE); + assert_eq!( + map.len(), + 3, + "1MB RAM should only have low + reserved + ACPI" + ); + } + + #[test] + fn test_build_e820_map_4096mb_has_mmio_and_apic_holes() { + let map = build_e820_map(4096, TEST_ACPI_BASE, TEST_ACPI_SIZE); + // Low + BIOS reserved + ACPI + high1 + VIRTIO reserved + high2 + // + APIC reserved + high3 = 8 entries. + assert_eq!( + map.len(), + 8, + "4GB RAM should have VIRTIO + APIC holes: {:?}", + map + ); + + // Low memory. + assert_eq!(map[0].addr, 0); + assert_eq!(map[0].entry_type, E820_RAM); + + // BIOS reserved. + assert_eq!(map[1].addr, 0x9FC00); + assert_eq!(map[1].entry_type, E820_RESERVED); + + // ACPI tables. + assert_eq!(map[2].addr, TEST_ACPI_BASE); + assert_eq!(map[2].size, TEST_ACPI_SIZE); + assert_eq!(map[2].entry_type, E820_ACPI); + + // High memory below Virtio MMIO. + assert_eq!(map[3].addr, 0x100000); + assert_eq!(map[3].size, VIRTIO_MMIO_BASE - 0x100000); + assert_eq!(map[3].entry_type, E820_RAM); + + // Virtio MMIO reserved region. 
+ assert_eq!(map[4].addr, VIRTIO_MMIO_BASE); + assert_eq!(map[4].size, MMIO_REGION_SIZE); + assert_eq!(map[4].entry_type, E820_RESERVED); + + // High memory between Virtio MMIO and APIC. + let mmio_end = VIRTIO_MMIO_BASE + MMIO_REGION_SIZE; + assert_eq!(map[5].addr, mmio_end); + assert_eq!(map[5].size, IOAPIC_MMIO_BASE - mmio_end); + assert_eq!(map[5].entry_type, E820_RAM); + + // APIC reserved region (IOAPIC + LAPIC). + let apic_end = LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE; + assert_eq!(map[6].addr, IOAPIC_MMIO_BASE); + assert_eq!(map[6].size, apic_end - IOAPIC_MMIO_BASE); + assert_eq!(map[6].entry_type, E820_RESERVED); + + // High memory above APIC. + assert_eq!(map[7].addr, apic_end); + assert_eq!(map[7].size, 4096 * 1024 * 1024 - apic_end); + assert_eq!(map[7].entry_type, E820_RAM); + } + + #[test] + fn test_build_e820_map_no_hole_below_mmio() { + // 3072 MB = 3GB < VIRTIO_MMIO_BASE (3.25GB) — no hole needed. + let map = build_e820_map(3072, TEST_ACPI_BASE, TEST_ACPI_SIZE); + assert_eq!(map.len(), 4, "3GB RAM should not have MMIO hole"); + assert_eq!(map[3].addr, 0x100000); + assert_eq!(map[3].size, 3072 * 1024 * 1024 - 0x100000); + assert_eq!(map[3].entry_type, E820_RAM); + } + + #[test] + fn test_build_e820_map_has_acpi_entry() { + let map = build_e820_map(256, TEST_ACPI_BASE, TEST_ACPI_SIZE); + let acpi_entry = map.iter().find(|e| e.entry_type == E820_ACPI); + assert!(acpi_entry.is_some(), "E820 map must contain ACPI entry"); + let entry = acpi_entry.unwrap(); + assert_eq!(entry.addr, TEST_ACPI_BASE); + assert_eq!(entry.size, TEST_ACPI_SIZE); + } +} diff --git a/src/vmm/src/windows/boot/mod.rs b/src/vmm/src/windows/boot/mod.rs new file mode 100644 index 000000000..17fb4b465 --- /dev/null +++ b/src/vmm/src/windows/boot/mod.rs @@ -0,0 +1,7 @@ +//! Linux kernel boot support for x86_64 (Windows WHPX backend). 
// Intel MultiProcessor Specification table generation.
//
// Generates the MP Floating Pointer Structure and MP Configuration Table
// so the Linux kernel can discover multiple vCPUs when CONFIG_ACPI is
// not enabled (CONFIG_X86_MPPARSE=y is sufficient).
//
// The kernel scans for the MP FPS in:
// - First 1KB of EBDA
// - Last 1KB of base memory (0x9FC00-0x9FFFF)
// - BIOS ROM area (0xF0000-0xFFFFF)

/// Guest physical address for the MP Floating Pointer Structure.
/// Placed at 0x9FC00 (start of the last 1KB of base memory).
pub const MP_FPS_ADDR: u64 = 0x9_FC00;

/// MP FPS size (always 16 bytes).
const FPS_SIZE: usize = 16;

/// Guest physical address for the MP Configuration Table.
/// Placed right after the 16-byte FPS.
const MP_TABLE_ADDR: u64 = MP_FPS_ADDR + FPS_SIZE as u64;

/// MP Configuration Table header size.
const MP_HEADER_SIZE: usize = 44;

/// Processor entry size (type 0).
const PROC_ENTRY_SIZE: usize = 20;

/// I/O APIC entry size (type 2).
const IOAPIC_ENTRY_SIZE: usize = 8;

/// LAPIC base address (must match memory.rs and acpi.rs).
const LAPIC_BASE: u32 = 0xFEE0_0000;

/// IOAPIC base address (must match memory.rs and acpi.rs).
const IOAPIC_BASE: u32 = 0xFEC0_0000;

/// Offset of the checksum byte inside the FPS.
const FPS_CHECKSUM_OFFSET: usize = 10;

/// Offset of the checksum byte inside the configuration table header.
const MP_HEADER_CHECKSUM_OFFSET: usize = 7;

/// Length of the MP Configuration Table (header + N processor entries +
/// one I/O APIC entry). Single source of truth for both the region size
/// and the "Base Table Length" header field.
fn mp_config_table_len(num_vcpus: u8) -> usize {
    MP_HEADER_SIZE + usize::from(num_vcpus) * PROC_ENTRY_SIZE + IOAPIC_ENTRY_SIZE
}

/// Total size needed for the MP table region (FPS + configuration table).
pub fn mp_region_size(num_vcpus: u8) -> usize {
    FPS_SIZE + mp_config_table_len(num_vcpus)
}

/// Build the MP Floating Pointer Structure (16 bytes) into `fps`.
///
/// `table_addr` is the guest physical address of the MP Configuration Table
/// that the structure points to. `fps` must be exactly `FPS_SIZE` bytes,
/// because the checksum is computed over the whole slice.
fn build_fps(fps: &mut [u8], table_addr: u32) {
    debug_assert_eq!(fps.len(), FPS_SIZE, "FPS buffer must be 16 bytes");

    // Signature "_MP_"
    fps[0..4].copy_from_slice(b"_MP_");
    // Physical Address of MP Configuration Table
    fps[4..8].copy_from_slice(&table_addr.to_le_bytes());
    // Length in 16-byte paragraphs (always 1)
    fps[8] = 1;
    // MP Specification revision (1.4)
    fps[9] = 4;
    // fps[10] is the checksum, filled in by mp_checksum below.
    // Feature bytes 1-5 (all 0 = use MP config table)
    fps[11..16].fill(0);

    mp_checksum(fps, FPS_CHECKSUM_OFFSET);
}

/// Build the MP Configuration Table into `table`.
///
/// Contains:
/// - 44-byte header
/// - N processor entries (20 bytes each); vCPU 0 is flagged as bootstrap
/// - 1 I/O APIC entry (8 bytes), with APIC ID `num_vcpus`
///
/// `table` must be at least `mp_config_table_len(num_vcpus)` bytes and is
/// expected to be zero-initialised: fields documented as zero are not
/// written explicitly.
fn build_mp_config_table(table: &mut [u8], num_vcpus: u8) {
    let entry_count = u16::from(num_vcpus) + 1; // N processors + 1 I/O APIC
    let base_table_length = mp_config_table_len(num_vcpus);
    debug_assert!(
        table.len() >= base_table_length,
        "MP config table buffer too small"
    );

    // ---- Header (44 bytes) ----
    table[0..4].copy_from_slice(b"PCMP"); // Signature
    // Base Table Length. At most 44 + 255 * 20 + 8 = 5152, so it fits in u16.
    table[4..6].copy_from_slice(&(base_table_length as u16).to_le_bytes());
    table[6] = 4; // Spec Revision (1.4)
    // table[7] = checksum (computed below)
    table[8..16].copy_from_slice(b"BOXLTE\0\0"); // OEM ID (8 bytes)
    table[16..28].copy_from_slice(b"BOXLITE-VM\0\0"); // Product ID (12 bytes)
    // OEM Table Pointer (offset 28, 4 bytes) = 0
    // OEM Table Size (offset 32, 2 bytes) = 0
    table[34..36].copy_from_slice(&entry_count.to_le_bytes()); // Entry Count
    table[36..40].copy_from_slice(&LAPIC_BASE.to_le_bytes()); // Local APIC Address
    // Extended Table Length (offset 40, 2 bytes) = 0
    // Extended Table Checksum (offset 42, 1 byte) = 0
    // Reserved (offset 43, 1 byte) = 0

    // ---- Processor entries (type 0, 20 bytes each) ----
    let (processors, ioapic) = table[MP_HEADER_SIZE..base_table_length]
        .split_at_mut(usize::from(num_vcpus) * PROC_ENTRY_SIZE);
    for (apic_id, entry) in (0..num_vcpus).zip(processors.chunks_exact_mut(PROC_ENTRY_SIZE)) {
        entry[0] = 0; // Entry type: Processor
        entry[1] = apic_id; // Local APIC ID
        entry[2] = 0x14; // Local APIC Version
        // CPU Flags: bit 0 = EN (usable), bit 1 = BP (bootstrap processor)
        entry[3] = if apic_id == 0 { 0x03 } else { 0x01 }; // BSP=3, AP=1
        // CPU Signature (4 bytes) — use a generic Family 6 Model signature.
        // This doesn't need to match the host exactly; the kernel reads CPUID directly.
        entry[4..8].copy_from_slice(&0x0006_0000u32.to_le_bytes());
        // Feature Flags (4 bytes) — written as zero (no feature bits advertised).
        entry[8..12].copy_from_slice(&0x0000_0000u32.to_le_bytes());
        // Reserved (8 bytes) = 0
    }

    // ---- I/O APIC entry (type 2, 8 bytes) ----
    ioapic[0] = 2; // Entry type: I/O APIC
    ioapic[1] = num_vcpus; // I/O APIC ID (after all LAPIC IDs)
    ioapic[2] = 0x20; // I/O APIC Version
    ioapic[3] = 0x01; // Flags: EN (enabled)
    ioapic[4..8].copy_from_slice(&IOAPIC_BASE.to_le_bytes()); // I/O APIC Address

    // Compute header checksum over the entire base table.
    mp_checksum(&mut table[..base_table_length], MP_HEADER_CHECKSUM_OFFSET);
}

/// Build the complete MP table region (FPS + Configuration Table).
///
/// Returns a `Vec<u8>` of `mp_region_size(num_vcpus)` bytes that should be
/// written to guest memory at `MP_FPS_ADDR`.
pub fn build_mp_tables(num_vcpus: u8) -> Vec<u8> {
    let mut region = vec![0u8; mp_region_size(num_vcpus)];
    let (fps, table) = region.split_at_mut(FPS_SIZE);

    // Build the FPS (first 16 bytes), pointing to the config table.
    build_fps(fps, MP_TABLE_ADDR as u32);

    // Build the MP Configuration Table (after the FPS).
    build_mp_config_table(table, num_vcpus);

    region
}

/// Compute MP checksum and store at `checksum_offset`.
///
/// The checksum byte is set so the sum of all bytes in `data` equals zero
/// (mod 256). Any previous value at `checksum_offset` is discarded first.
fn mp_checksum(data: &mut [u8], checksum_offset: usize) {
    data[checksum_offset] = 0;
    let sum = data.iter().fold(0u8, |acc, &b| acc.wrapping_add(b));
    data[checksum_offset] = 0u8.wrapping_sub(sum);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fps_signature_and_checksum() {
        let region = build_mp_tables(2);
        let fps = &region[..FPS_SIZE];

        assert_eq!(&fps[0..4], b"_MP_");
        let sum: u8 = fps.iter().fold(0u8, |acc, &b| acc.wrapping_add(b));
        assert_eq!(sum, 0, "FPS checksum must be zero");
    }

    #[test]
    fn test_fps_points_to_table() {
        let region = build_mp_tables(2);
        let fps = &region[..FPS_SIZE];

        let table_addr = u32::from_le_bytes(fps[4..8].try_into().unwrap());
        assert_eq!(table_addr, MP_TABLE_ADDR as u32);
    }

    #[test]
    fn test_fps_revision() {
        let region = build_mp_tables(2);
        assert_eq!(region[8], 1, "FPS length must be 1 paragraph");
        assert_eq!(region[9], 4, "FPS revision must be 1.4");
    }

    #[test]
    fn test_mp_table_signature_and_checksum() {
        let region = build_mp_tables(2);
        let table_start = FPS_SIZE;
        let table_len = MP_HEADER_SIZE + 2 * PROC_ENTRY_SIZE + IOAPIC_ENTRY_SIZE;
        let table = &region[table_start..table_start + table_len];

        assert_eq!(&table[0..4], b"PCMP");

        let length = u16::from_le_bytes(table[4..6].try_into().unwrap());
        assert_eq!(length, table_len as u16);

        let sum: u8 = table.iter().fold(0u8, |acc, &b| acc.wrapping_add(b));
        assert_eq!(sum, 0, "MP table checksum must be zero");
    }

    #[test]
    fn test_mp_table_processor_entries() {
        let region = build_mp_tables(4);
        let table_start = FPS_SIZE;

        // Check entry count
        let entry_count = u16::from_le_bytes(
            region[table_start + 34..table_start + 36]
                .try_into()
                .unwrap(),
        );
        assert_eq!(entry_count, 5, "4 processors + 1 I/O APIC");

        // Check processor entries
        for i in 0..4u8 {
            let off = table_start + MP_HEADER_SIZE + (i as usize) * PROC_ENTRY_SIZE;
            assert_eq!(region[off], 0, "entry type: Processor for vCPU {}", i);
            assert_eq!(region[off + 1], i, "APIC ID for vCPU {}", i);
            assert_eq!(region[off + 2], 0x14, "APIC version for vCPU {}", i);
            if i == 0 {
                assert_eq!(region[off + 3], 0x03, "BSP flags (EN + BP)");
            } else {
                assert_eq!(region[off + 3], 0x01, "AP flags (EN only)");
            }
        }
    }

    #[test]
    fn test_mp_table_ioapic_entry() {
        let region = build_mp_tables(2);
        let table_start = FPS_SIZE;
        let ioapic_off = table_start + MP_HEADER_SIZE + 2 * PROC_ENTRY_SIZE;

        assert_eq!(region[ioapic_off], 2, "entry type: I/O APIC");
        assert_eq!(region[ioapic_off + 1], 2, "I/O APIC ID");
        assert_eq!(region[ioapic_off + 3], 0x01, "enabled flag");

        let addr = u32::from_le_bytes(region[ioapic_off + 4..ioapic_off + 8].try_into().unwrap());
        assert_eq!(addr, IOAPIC_BASE);
    }

    #[test]
    fn test_mp_table_lapic_address() {
        let region = build_mp_tables(1);
        let table_start = FPS_SIZE;

        let lapic_addr = u32::from_le_bytes(
            region[table_start + 36..table_start + 40]
                .try_into()
                .unwrap(),
        );
        assert_eq!(lapic_addr, LAPIC_BASE);
    }

    #[test]
    fn test_single_vcpu() {
        let region = build_mp_tables(1);
        let total = mp_region_size(1);
        assert_eq!(region.len(), total);

        let table_start = FPS_SIZE;
        let entry_count = u16::from_le_bytes(
            region[table_start + 34..table_start + 36]
                .try_into()
                .unwrap(),
        );
        assert_eq!(entry_count, 2, "1 processor + 1 I/O APIC");
    }

    #[test]
    fn test_region_fits_in_base_memory() {
        // MP tables for up to 16 vCPUs must fit in the scan area.
        let max_size = mp_region_size(16);
        // FPS at 0x9FC00, scan area is 0x9FC00-0x9FFFF (1024 bytes).
        assert!(
            max_size <= 1024,
            "MP tables for 16 vCPUs ({} bytes) exceed scan area",
            max_size,
        );
    }
}
Linux boot_params (zero page) structure. +//! +//! Subset of the Linux boot protocol's boot_params structure +//! needed for direct bzImage boot. + +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; + +/// E820 memory map entry type constants. +pub const E820_RAM: u32 = 1; +pub const E820_RESERVED: u32 = 2; +pub const E820_ACPI: u32 = 3; + +/// Linux boot protocol magic number. +pub const BOOT_MAGIC: u16 = 0xAA55; + +/// Header magic "HdrS". +pub const HDRS_MAGIC: u32 = 0x5372_6448; + +/// Minimum boot protocol version we support (2.06+). +pub const MIN_BOOT_PROTOCOL: u16 = 0x0206; + +/// E820 memory map entry. +#[repr(C)] +#[derive(Debug, Default, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)] +pub struct E820Entry { + pub addr: u64, + pub size: u64, + pub entry_type: u32, + pub _pad: u32, +} + +/// Minimal subset of Linux setup_header structure. +/// Located at offset 0x1F1 in the zero page. +#[repr(C, packed)] +#[derive(Debug, Default, Clone, Copy)] +pub struct SetupHeader { + pub setup_sects: u8, + pub root_flags: u16, + pub syssize: u32, + pub ram_size: u16, + pub vid_mode: u16, + pub root_dev: u16, + pub boot_flag: u16, + pub jump: u16, + pub header: u32, + pub version: u16, + pub realmode_swtch: u32, + pub start_sys_seg: u16, + pub kernel_version: u16, + pub type_of_loader: u8, + pub loadflags: u8, + pub setup_move_size: u16, + pub code32_start: u32, + pub ramdisk_image: u32, + pub ramdisk_size: u32, + pub bootsect_kludge: u32, + pub heap_end_ptr: u16, + pub ext_loader_ver: u8, + pub ext_loader_type: u8, + pub cmd_line_ptr: u32, + pub initrd_addr_max: u32, + pub kernel_alignment: u32, + pub relocatable_kernel: u8, + pub min_alignment: u8, + pub xloadflags: u16, + pub cmdline_size: u32, + pub hardware_subarch: u32, + pub hardware_subarch_data: u64, + pub payload_offset: u32, + pub payload_length: u32, + pub setup_data: u64, + pub pref_address: u64, + pub init_size: u32, + pub handover_offset: u32, +} + +/// Boot parameters (zero page) 
— the key structure passed to the Linux kernel. +pub struct BootParams { + /// The raw 4096-byte zero page buffer. + pub data: [u8; 4096], +} + +impl Default for BootParams { + fn default() -> Self { + BootParams { data: [0u8; 4096] } + } +} + +impl BootParams { + /// Create a new BootParams with default values. + pub fn new() -> Self { + Self::default() + } + + /// Set the E820 memory map. + pub fn set_e820_map(&mut self, entries: &[E820Entry]) { + let count = entries.len().min(128) as u8; + self.data[0x1E8] = count; + + let base_offset = 0x2D0; + for (i, entry) in entries.iter().take(128).enumerate() { + let offset = base_offset + i * 20; + self.data[offset..offset + 8].copy_from_slice(&entry.addr.to_le_bytes()); + self.data[offset + 8..offset + 16].copy_from_slice(&entry.size.to_le_bytes()); + self.data[offset + 16..offset + 20].copy_from_slice(&entry.entry_type.to_le_bytes()); + } + } + + /// Set the command line pointer. + pub fn set_cmdline_ptr(&mut self, addr: u32) { + self.data[0x228..0x22C].copy_from_slice(&addr.to_le_bytes()); + } + + /// Set the command line size. + pub fn set_cmdline_size(&mut self, size: u32) { + self.data[0x238..0x23C].copy_from_slice(&size.to_le_bytes()); + } + + /// Set the boot flag (must be 0xAA55). + pub fn set_boot_flag(&mut self) { + self.data[0x1FE..0x200].copy_from_slice(&BOOT_MAGIC.to_le_bytes()); + } + + /// Set the setup header magic ("HdrS"). + pub fn set_header_magic(&mut self) { + self.data[0x202..0x206].copy_from_slice(&HDRS_MAGIC.to_le_bytes()); + } + + /// Set the boot protocol version. + pub fn set_version(&mut self, version: u16) { + self.data[0x206..0x208].copy_from_slice(&version.to_le_bytes()); + } + + /// Set the type_of_loader field (0xFF = undefined bootloader). + pub fn set_loader_type(&mut self, loader_type: u8) { + self.data[0x210] = loader_type; + } + + /// Set load flags. + pub fn set_loadflags(&mut self, flags: u8) { + self.data[0x211] = flags; + } + + /// Set the ramdisk image address. 
+ pub fn set_ramdisk(&mut self, addr: u32, size: u32) { + self.data[0x218..0x21C].copy_from_slice(&addr.to_le_bytes()); + self.data[0x21C..0x220].copy_from_slice(&size.to_le_bytes()); + } + + /// Set the ACPI RSDP physical address (boot protocol 2.14+, offset 0x070). + /// + /// When set, the kernel uses this address directly instead of scanning + /// the BIOS ROM area (0xE0000-0xFFFFF) for the RSDP signature. + /// For older kernels (protocol < 2.14), this field is padding and ignored. + pub fn set_acpi_rsdp_addr(&mut self, addr: u64) { + self.data[0x070..0x078].copy_from_slice(&addr.to_le_bytes()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_boot_params_default_is_zeroed() { + let params = BootParams::new(); + assert!(params.data.iter().all(|&b| b == 0)); + } + + #[test] + fn test_boot_params_set_boot_flag() { + let mut params = BootParams::new(); + params.set_boot_flag(); + let flag = u16::from_le_bytes([params.data[0x1FE], params.data[0x1FF]]); + assert_eq!(flag, BOOT_MAGIC); + } + + #[test] + fn test_boot_params_set_header_magic() { + let mut params = BootParams::new(); + params.set_header_magic(); + let magic = u32::from_le_bytes([ + params.data[0x202], + params.data[0x203], + params.data[0x204], + params.data[0x205], + ]); + assert_eq!(magic, HDRS_MAGIC); + } + + #[test] + fn test_boot_params_set_cmdline() { + let mut params = BootParams::new(); + params.set_cmdline_ptr(0x20000); + params.set_cmdline_size(256); + + let ptr = u32::from_le_bytes(params.data[0x228..0x22C].try_into().unwrap()); + let size = u32::from_le_bytes(params.data[0x238..0x23C].try_into().unwrap()); + assert_eq!(ptr, 0x20000); + assert_eq!(size, 256); + } + + #[test] + fn test_boot_params_e820_map() { + let mut params = BootParams::new(); + let entries = vec![ + E820Entry { + addr: 0, + size: 0x9FC00, + entry_type: E820_RAM, + _pad: 0, + }, + E820Entry { + addr: 0x100000, + size: 255 * 1024 * 1024, + entry_type: E820_RAM, + _pad: 0, + }, + ]; + + 
params.set_e820_map(&entries); + assert_eq!(params.data[0x1E8], 2); + + let addr = u64::from_le_bytes(params.data[0x2D0..0x2D8].try_into().unwrap()); + let size = u64::from_le_bytes(params.data[0x2D8..0x2E0].try_into().unwrap()); + let etype = u32::from_le_bytes(params.data[0x2E0..0x2E4].try_into().unwrap()); + assert_eq!(addr, 0); + assert_eq!(size, 0x9FC00); + assert_eq!(etype, E820_RAM); + } + + #[test] + fn test_e820_entry_size() { + assert_eq!(std::mem::size_of::(), 24); + } + + #[test] + fn test_boot_params_loader_type() { + let mut params = BootParams::new(); + params.set_loader_type(0xFF); + assert_eq!(params.data[0x210], 0xFF); + } + + #[test] + fn test_boot_params_ramdisk() { + let mut params = BootParams::new(); + params.set_ramdisk(0x1000000, 0x500000); + + let addr = u32::from_le_bytes(params.data[0x218..0x21C].try_into().unwrap()); + let size = u32::from_le_bytes(params.data[0x21C..0x220].try_into().unwrap()); + assert_eq!(addr, 0x1000000); + assert_eq!(size, 0x500000); + } + + #[test] + fn test_boot_params_acpi_rsdp_addr() { + let mut params = BootParams::new(); + params.set_acpi_rsdp_addr(0xE0000); + + let addr = u64::from_le_bytes(params.data[0x070..0x078].try_into().unwrap()); + assert_eq!(addr, 0xE0000); + } + + #[test] + fn test_boot_params_acpi_rsdp_addr_default_zero() { + let params = BootParams::new(); + let addr = u64::from_le_bytes(params.data[0x070..0x078].try_into().unwrap()); + assert_eq!(addr, 0, "acpi_rsdp_addr should be 0 by default"); + } +} diff --git a/src/vmm/src/windows/boot/setup.rs b/src/vmm/src/windows/boot/setup.rs new file mode 100644 index 000000000..444e943c9 --- /dev/null +++ b/src/vmm/src/windows/boot/setup.rs @@ -0,0 +1,315 @@ +//! x86_64 boot setup — page tables, GDT, and vCPU register configuration +//! for the Windows WHPX backend. 
+ +use super::super::types::{DescriptorTable, SegmentRegister, SpecialRegisters, StandardRegisters}; + +// Page table constants +const PAGE_PRESENT: u64 = 1 << 0; +const PAGE_WRITE: u64 = 1 << 1; +const PAGE_SIZE_2MB: u64 = 1 << 7; + +// Control register bits +const CR0_PE: u64 = 1 << 0; +const CR0_ET: u64 = 1 << 4; +const CR0_NE: u64 = 1 << 5; +const CR0_WP: u64 = 1 << 16; +const CR0_AM: u64 = 1 << 18; +const CR0_PG: u64 = 1 << 31; + +const CR4_PAE: u64 = 1 << 5; +const CR4_OSFXSR: u64 = 1 << 9; +const CR4_OSXMMEXCPT: u64 = 1 << 10; + +const EFER_LME: u64 = 1 << 8; +const EFER_LMA: u64 = 1 << 10; +const EFER_SCE: u64 = 1 << 0; + +// GDT entry access byte and flags +const GDT_CODE_ACCESS: u16 = 0xA09B; +const GDT_DATA_ACCESS: u16 = 0xC093; +const GDT_TSS_ACCESS: u16 = 0x808B; + +/// Memory addresses for page table structures. +const PML4_ADDR: u64 = 0x9000; +const PDPT_ADDR: u64 = 0xA000; +const PD_ADDR: u64 = 0xB000; +const GDT_ADDR: u64 = 0x500; +const BOOT_STACK: u64 = 0x8FF0; + +/// GDT entry indices +const GDT_NULL: usize = 0; +const GDT_CODE: usize = 1; +const GDT_DATA: usize = 2; +const GDT_TSS: usize = 3; + +/// Number of GDT entries (null + code + data + TSS = 4) +const GDT_ENTRY_COUNT: usize = 4; + +/// Build identity-mapped page tables for 4GB. +pub fn build_page_tables() -> PageTables { + let mut pml4 = [0u64; 512]; + let mut pdpt = [0u64; 512]; + let mut pd = [[0u64; 512]; 4]; + + pml4[0] = PDPT_ADDR | PAGE_PRESENT | PAGE_WRITE; + + for (i, entry) in pdpt.iter_mut().enumerate().take(4) { + *entry = (PD_ADDR + i as u64 * 0x1000) | PAGE_PRESENT | PAGE_WRITE; + } + + for (i, pd_table) in pd.iter_mut().enumerate() { + for (j, entry) in pd_table.iter_mut().enumerate() { + let phys_addr = (i as u64 * 512 + j as u64) * (2 * 1024 * 1024); + *entry = phys_addr | PAGE_PRESENT | PAGE_WRITE | PAGE_SIZE_2MB; + } + } + + PageTables { pml4, pdpt, pd } +} + +/// Page table data ready to be written to guest memory. 
+pub struct PageTables { + pub pml4: [u64; 512], + pub pdpt: [u64; 512], + pub pd: [[u64; 512]; 4], +} + +impl PageTables { + pub fn pml4_bytes(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.pml4.as_ptr() as *const u8, 512 * 8) } + } + + pub fn pdpt_bytes(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.pdpt.as_ptr() as *const u8, 512 * 8) } + } + + pub fn pd_bytes(&self, index: usize) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.pd[index].as_ptr() as *const u8, 512 * 8) } + } +} + +/// Build the GDT entries. +pub fn build_gdt() -> Vec { + let mut gdt = vec![0u64; GDT_ENTRY_COUNT + 1]; + + gdt[GDT_NULL] = 0; + gdt[GDT_CODE] = gdt_entry(0, 0xFFFFF, GDT_CODE_ACCESS); + gdt[GDT_DATA] = gdt_entry(0, 0xFFFFF, GDT_DATA_ACCESS); + gdt[GDT_TSS] = gdt_entry(0, 0xFFFF, GDT_TSS_ACCESS); + gdt[GDT_TSS + 1] = 0; + + gdt +} + +fn gdt_entry(base: u32, limit: u32, access_rights: u16) -> u64 { + let access = (access_rights & 0xFF) as u64; + let flags = ((access_rights >> 8) & 0xF0) as u64; + let limit_low = (limit & 0xFFFF) as u64; + let limit_high = ((limit >> 16) & 0xF) as u64; + let base_low = (base & 0xFFFF) as u64; + let base_mid = ((base >> 16) & 0xFF) as u64; + let base_high = ((base >> 24) & 0xFF) as u64; + + limit_low + | (base_low << 16) + | (base_mid << 32) + | (access << 40) + | (limit_high << 48) + | (flags << 48) + | (base_high << 56) +} + +/// GDT data as bytes. +pub fn gdt_bytes(gdt: &[u64]) -> Vec { + let mut bytes = Vec::with_capacity(gdt.len() * 8); + for entry in gdt { + bytes.extend_from_slice(&entry.to_le_bytes()); + } + bytes +} + +/// Configure the initial vCPU registers for 64-bit long mode boot. 
+pub fn configure_boot_registers(kernel_entry: u64) -> (StandardRegisters, SpecialRegisters) { + let regs = StandardRegisters { + rip: kernel_entry, + rsp: BOOT_STACK, + rsi: super::super::memory::ZERO_PAGE_START, + rflags: 0x2, + ..Default::default() + }; + + let gdt_size = (GDT_ENTRY_COUNT + 1) * 8; + + let sregs = SpecialRegisters { + cs: SegmentRegister { + base: 0, + limit: 0xFFFF_FFFF, + selector: 0x08, + access_rights: GDT_CODE_ACCESS, + }, + ds: SegmentRegister { + base: 0, + limit: 0xFFFF_FFFF, + selector: 0x10, + access_rights: GDT_DATA_ACCESS, + }, + es: SegmentRegister { + base: 0, + limit: 0xFFFF_FFFF, + selector: 0x10, + access_rights: GDT_DATA_ACCESS, + }, + fs: SegmentRegister { + base: 0, + limit: 0xFFFF_FFFF, + selector: 0x10, + access_rights: GDT_DATA_ACCESS, + }, + gs: SegmentRegister { + base: 0, + limit: 0xFFFF_FFFF, + selector: 0x10, + access_rights: GDT_DATA_ACCESS, + }, + ss: SegmentRegister { + base: 0, + limit: 0xFFFF_FFFF, + selector: 0x10, + access_rights: GDT_DATA_ACCESS, + }, + tr: SegmentRegister { + base: 0, + limit: 0xFFFF, + selector: 0x18, + access_rights: GDT_TSS_ACCESS, + }, + ldt: SegmentRegister::default(), + gdt: DescriptorTable { + base: GDT_ADDR, + limit: (gdt_size - 1) as u16, + }, + idt: DescriptorTable { + base: 0, + limit: 0xFFFF, + }, + cr0: CR0_PE | CR0_ET | CR0_NE | CR0_WP | CR0_AM | CR0_PG, + cr2: 0, + cr3: PML4_ADDR, + cr4: CR4_PAE | CR4_OSFXSR | CR4_OSXMMEXCPT, + efer: EFER_LME | EFER_LMA | EFER_SCE, + }; + + (regs, sregs) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_page_tables_pml4_points_to_pdpt() { + let pt = build_page_tables(); + let entry = pt.pml4[0]; + assert_eq!(entry & !0xFFF, PDPT_ADDR); + assert_ne!(entry & PAGE_PRESENT, 0); + assert_ne!(entry & PAGE_WRITE, 0); + } + + #[test] + fn test_page_tables_pdpt_entries() { + let pt = build_page_tables(); + for i in 0..4 { + let entry = pt.pdpt[i]; + let expected_addr = PD_ADDR + i as u64 * 0x1000; + assert_eq!(entry & !0xFFF, 
expected_addr); + assert_ne!(entry & PAGE_PRESENT, 0); + } + for i in 4..512 { + assert_eq!(pt.pdpt[i], 0, "PDPT[{}] should be empty", i); + } + } + + #[test] + fn test_page_tables_identity_map() { + let pt = build_page_tables(); + for i in 0..4 { + for j in 0..512 { + let entry = pt.pd[i][j]; + let expected_phys = (i as u64 * 512 + j as u64) * 2 * 1024 * 1024; + assert_eq!(entry & !0xFFF, expected_phys); + assert_ne!(entry & PAGE_PRESENT, 0); + assert_ne!(entry & PAGE_SIZE_2MB, 0); + } + } + } + + #[test] + fn test_page_tables_cover_4gb() { + let pt = build_page_tables(); + let last_entry = pt.pd[3][511]; + let last_addr = last_entry & !0xFFF; + let expected = (4u64 * 1024 * 1024 * 1024) - (2 * 1024 * 1024); + assert_eq!(last_addr, expected); + } + + #[test] + fn test_gdt_has_null_entry() { + let gdt = build_gdt(); + assert_eq!(gdt[GDT_NULL], 0); + } + + #[test] + fn test_gdt_code_segment() { + let gdt = build_gdt(); + assert_ne!(gdt[GDT_CODE], 0); + } + + #[test] + fn test_gdt_data_segment() { + let gdt = build_gdt(); + assert_ne!(gdt[GDT_DATA], 0); + } + + #[test] + fn test_gdt_bytes_length() { + let gdt = build_gdt(); + let bytes = gdt_bytes(&gdt); + assert_eq!(bytes.len(), gdt.len() * 8); + } + + #[test] + fn test_boot_registers_long_mode() { + let (regs, sregs) = configure_boot_registers(0x100000); + + assert_eq!(regs.rip, 0x100000); + assert_eq!(regs.rsp, BOOT_STACK); + assert_eq!(regs.rsi, super::super::memory::ZERO_PAGE_START); + assert_ne!(regs.rflags & 0x2, 0); + assert_ne!(sregs.cr0 & CR0_PE, 0); + assert_ne!(sregs.cr0 & CR0_PG, 0); + assert_eq!(sregs.cr3, PML4_ADDR); + assert_ne!(sregs.cr4 & CR4_PAE, 0); + assert_ne!(sregs.efer & EFER_LME, 0); + assert_ne!(sregs.efer & EFER_LMA, 0); + } + + #[test] + fn test_boot_registers_segment_selectors() { + let (_, sregs) = configure_boot_registers(0x100000); + + assert_eq!(sregs.cs.selector, 0x08); + assert_eq!(sregs.ds.selector, 0x10); + assert_eq!(sregs.es.selector, 0x10); + assert_eq!(sregs.ss.selector, 
0x10); + assert_eq!(sregs.tr.selector, 0x18); + } + + #[test] + fn test_gdt_entry_encoding() { + let entry = gdt_entry(0, 0xFFFFF, GDT_CODE_ACCESS); + assert_ne!(entry, 0); + + let null = gdt_entry(0, 0, 0); + assert_eq!(null, 0); + } +} diff --git a/src/vmm/src/windows/cmdline.rs b/src/vmm/src/windows/cmdline.rs new file mode 100644 index 000000000..bac75f617 --- /dev/null +++ b/src/vmm/src/windows/cmdline.rs @@ -0,0 +1,467 @@ +//! Kernel command line builder for the Windows WHPX backend. + +use super::memory::VIRTIO_MMIO_BASE; + +/// Size of each virtio-MMIO device slot in bytes. +pub const MMIO_SLOT_SIZE: u64 = 0x200; + +/// IRQ number for the first MMIO device slot. +pub const FIRST_MMIO_IRQ: u8 = 5; + +/// Base kernel command line parameters (quiet mode — fast boot). +/// +/// Quiet mode suppresses serial console output and i8042 keyboard probing, +/// eliminating ~36K VM exits per boot (~26K serial + ~10K i8042). This reduces +/// WHPX boot time from ~5s to ~1-2s. +/// +/// - `console=ttyS0`: Route kernel console to serial port (required — no VGA). +/// - `quiet loglevel=1`: Suppress kernel printk to console. +/// - `i8042.nokbd i8042.noaux`: Skip PS/2 keyboard/mouse probe (10K+ exits). +/// - `nohyperv`: Disable Hyper-V guest enlightenments. WHPX exposes Hyper-V +/// CPUID leaves but doesn't fully support synthetic timers/SynIC, causing +/// clock stalls if the kernel tries to use them. +/// - `lpj=1000000`: Preset loops_per_jiffy to skip delay calibration, which +/// depends on a reliable timer source. +/// - `nokaslr`: Disable kernel address space randomization for deterministic +/// boot in our controlled environment. +/// +/// Note: `noapic` and `nolapic` are NOT present — the MADT table in ACPI +/// tells the kernel about the IOAPIC and LAPIC for APIC-mode interrupt routing. +/// Note: `nosmp` is NOT present — multi-vCPU is supported via MADT LAPIC entries. 
+const BASE_CMDLINE: &str = + "console=ttyS0 quiet loglevel=1 i8042.nokbd i8042.noaux nohyperv lpj=1000000 nokaslr"; + +/// Serial console parameters appended in verbose mode. +/// +/// Enables full kernel boot output on the serial console. Useful for debugging +/// but adds ~26K VM exits (~3s) due to per-byte serial I/O port access. +const VERBOSE_CONSOLE: &str = "console=ttyS0 earlyprintk=serial,ttyS0,115200"; + +/// Description of a virtio-MMIO device slot for command line generation. +#[derive(Debug, Clone)] +pub struct MmioSlot { + /// Slot index (0-based). Determines MMIO base address and IRQ. + pub index: u8, + /// Whether the slot is active (has a device). + pub active: bool, +} + +/// Build the full kernel command line. +/// +/// Parameters: +/// - `user_cmdline`: Extra kernel parameters appended after device config. +/// - `has_root_disk`: Whether a root disk is attached (default `/dev/vda`). +/// - `mmio_slots`: Virtio-MMIO device slots to register. +/// - `root_disk_device`: Override root device (e.g., "/dev/vdb"). Takes priority over `has_root_disk`. +/// - `root_disk_fstype`: Filesystem type for root device (e.g., "ext4"). +/// - `exec_path`: Path to init binary (added as `init=`). +/// - `exec_argv`: Arguments passed after `--` separator for the init process. +/// - `verbose`: Enable serial console output. Adds `console=ttyS0` and removes +/// `quiet`/`i8042.nokbd` for full kernel boot logging. Slower (~5s vs ~1-2s). +pub fn build_kernel_cmdline( + user_cmdline: Option<&str>, + has_root_disk: bool, + mmio_slots: &[MmioSlot], + root_disk_device: Option<&str>, + root_disk_fstype: Option<&str>, + exec_path: Option<&str>, + exec_argv: &[String], + verbose: bool, +) -> String { + let mut cmdline = if verbose { + // Verbose mode: serial console + full i8042 probe for debugging. + // No noapic/nolapic — APIC mode is enabled via MADT (same as quiet mode). 
+ format!("{} nohyperv lpj=1000000 nokaslr", VERBOSE_CONSOLE) + } else { + BASE_CMDLINE.to_string() + }; + + // Root device: explicit override takes priority over default. + if let Some(device) = root_disk_device { + cmdline.push_str(&format!(" root={}", device)); + if let Some(fstype) = root_disk_fstype { + cmdline.push_str(&format!(" rootfstype={}", fstype)); + } + cmdline.push_str(" rw"); + } else if has_root_disk { + cmdline.push_str(" root=/dev/vda rw"); + } + + // Init binary path. + if let Some(path) = exec_path { + cmdline.push_str(&format!(" init={}", path)); + } + + for slot in mmio_slots { + if !slot.active { + continue; + } + let base = VIRTIO_MMIO_BASE + (slot.index as u64) * MMIO_SLOT_SIZE; + let irq = FIRST_MMIO_IRQ + slot.index; + cmdline.push_str(&format!( + " virtio_mmio.device={}@0x{:x}:{}", + MMIO_SLOT_SIZE, base, irq + )); + } + + if let Some(extra) = user_cmdline { + if !extra.is_empty() { + cmdline.push(' '); + cmdline.push_str(extra); + } + } + + // Init arguments after separator. + if !exec_argv.is_empty() { + cmdline.push_str(" -- "); + cmdline.push_str(&exec_argv.join(" ")); + } + + cmdline +} + +/// Calculate the MMIO base address for a given slot index. +pub fn mmio_base_for_slot(index: u8) -> u64 { + VIRTIO_MMIO_BASE + (index as u64) * MMIO_SLOT_SIZE +} + +/// Calculate the IRQ number for a given slot index. +pub fn irq_for_slot(index: u8) -> u8 { + FIRST_MMIO_IRQ + index +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: build cmdline with only the legacy params (no root override, no init). 
+ fn build_simple(user: Option<&str>, has_root: bool, slots: &[MmioSlot]) -> String { + build_kernel_cmdline(user, has_root, slots, None, None, None, &[], false) + } + + #[test] + fn test_base_cmdline_only() { + let cmdline = build_simple(None, false, &[]); + assert_eq!(cmdline, BASE_CMDLINE); + } + + #[test] + fn test_quiet_mode_default() { + let cmdline = build_simple(None, false, &[]); + assert!(cmdline.contains("console=ttyS0")); + assert!(cmdline.contains("quiet")); + assert!(cmdline.contains("loglevel=1")); + assert!(cmdline.contains("i8042.nokbd")); + assert!(cmdline.contains("i8042.noaux")); + assert!(!cmdline.contains("earlyprintk")); + } + + #[test] + fn test_verbose_mode() { + let cmdline = build_kernel_cmdline(None, false, &[], None, None, None, &[], true); + assert!(cmdline.contains("console=ttyS0")); + assert!(cmdline.contains("earlyprintk=serial,ttyS0,115200")); + assert!(!cmdline.contains("quiet")); + assert!(!cmdline.contains("loglevel=1")); + assert!(!cmdline.contains("i8042.nokbd")); + // Common params present in both modes. 
+ assert!(cmdline.contains("nohyperv")); + assert!(cmdline.contains("lpj=1000000")); + assert!(cmdline.contains("nokaslr")); + } + + #[test] + fn test_with_root_disk() { + let cmdline = build_simple(None, true, &[]); + assert!(cmdline.contains("root=/dev/vda rw")); + assert!(cmdline.starts_with(BASE_CMDLINE)); + } + + #[test] + fn test_with_mmio_slots() { + let slots = vec![ + MmioSlot { + index: 0, + active: true, + }, + MmioSlot { + index: 1, + active: true, + }, + ]; + let cmdline = build_simple(None, true, &slots); + assert!(cmdline.contains("virtio_mmio.device=512@0xd0000000:5")); + assert!(cmdline.contains("virtio_mmio.device=512@0xd0000200:6")); + } + + #[test] + fn test_inactive_slots_skipped() { + let slots = vec![ + MmioSlot { + index: 0, + active: true, + }, + MmioSlot { + index: 1, + active: false, + }, + MmioSlot { + index: 2, + active: true, + }, + ]; + let cmdline = build_simple(None, false, &slots); + assert!(cmdline.contains("virtio_mmio.device=512@0xd0000000:5")); + assert!(!cmdline.contains("0xd0000200")); + assert!(cmdline.contains("virtio_mmio.device=512@0xd0000400:7")); + } + + #[test] + fn test_user_cmdline_appended() { + let cmdline = build_simple(Some("custom_param=1"), false, &[]); + assert!(cmdline.ends_with("custom_param=1")); + } + + #[test] + fn test_empty_user_cmdline_no_trailing_space() { + let cmdline = build_simple(Some(""), false, &[]); + assert!(!cmdline.ends_with(' ')); + assert_eq!(cmdline, BASE_CMDLINE); + } + + #[test] + fn test_base_cmdline_has_nohyperv() { + let cmdline = build_simple(None, false, &[]); + assert!(cmdline.contains("nohyperv")); + assert!(cmdline.contains("lpj=1000000")); + assert!(cmdline.contains("nokaslr")); + // noacpi must NOT be present (ACPI tables are provided). + assert!(!cmdline.contains("noacpi")); + } + + #[test] + fn test_cmdline_no_nosmp() { + // nosmp must NOT be present — multi-vCPU is supported via MADT LAPIC entries. 
+ let quiet = build_simple(None, false, &[]); + assert!( + !quiet.contains("nosmp"), + "quiet cmdline must not contain nosmp (multi-vCPU enabled)" + ); + + let verbose = build_kernel_cmdline(None, false, &[], None, None, None, &[], true); + assert!( + !verbose.contains("nosmp"), + "verbose cmdline must not contain nosmp (multi-vCPU enabled)" + ); + } + + #[test] + fn test_cmdline_no_noacpi_no_noapic() { + // Verify neither quiet nor verbose mode includes noacpi or noapic. + // APIC mode is enabled via MADT; noapic/nolapic would disable it. + let quiet = build_simple(None, false, &[]); + assert!( + !quiet.contains("noacpi"), + "quiet cmdline must not contain noacpi" + ); + assert!( + !quiet.contains("noapic"), + "quiet cmdline must not contain noapic (APIC enabled via MADT)" + ); + assert!( + !quiet.contains("nolapic"), + "quiet cmdline must not contain nolapic" + ); + + let verbose = build_kernel_cmdline(None, false, &[], None, None, None, &[], true); + assert!( + !verbose.contains("noacpi"), + "verbose cmdline must not contain noacpi" + ); + assert!( + !verbose.contains("noapic"), + "verbose cmdline must not contain noapic (APIC enabled via MADT)" + ); + assert!( + !verbose.contains("nolapic"), + "verbose cmdline must not contain nolapic" + ); + } + + #[test] + fn test_mmio_base_for_slot() { + assert_eq!(mmio_base_for_slot(0), 0xD000_0000); + assert_eq!(mmio_base_for_slot(1), 0xD000_0200); + assert_eq!(mmio_base_for_slot(2), 0xD000_0400); + } + + #[test] + fn test_irq_for_slot() { + assert_eq!(irq_for_slot(0), 5); + assert_eq!(irq_for_slot(1), 6); + assert_eq!(irq_for_slot(2), 7); + } + + #[test] + fn test_full_cmdline_with_all_options() { + let slots = vec![ + MmioSlot { + index: 0, + active: true, + }, + MmioSlot { + index: 1, + active: true, + }, + MmioSlot { + index: 2, + active: true, + }, + ]; + let cmdline = build_simple(Some("custom_test=1"), true, &slots); + + let base_pos = cmdline.find(BASE_CMDLINE).unwrap(); + let root_pos = 
cmdline.find("root=/dev/vda").unwrap(); + let mmio0_pos = cmdline.find("0xd0000000:5").unwrap(); + let mmio1_pos = cmdline.find("0xd0000200:6").unwrap(); + let mmio2_pos = cmdline.find("0xd0000400:7").unwrap(); + let user_pos = cmdline.find("custom_test=1").unwrap(); + + assert!(base_pos < root_pos); + assert!(root_pos < mmio0_pos); + assert!(mmio0_pos < mmio1_pos); + assert!(mmio1_pos < mmio2_pos); + assert!(mmio2_pos < user_pos); + } + + // ---- New tests for root_disk_device, exec_path, exec_argv ---- + + #[test] + fn test_root_disk_device_override() { + let cmdline = build_kernel_cmdline( + None, + false, + &[], + Some("/dev/vdb"), + Some("ext4"), + None, + &[], + false, + ); + assert!(cmdline.contains("root=/dev/vdb")); + assert!(cmdline.contains("rootfstype=ext4")); + assert!(cmdline.contains("rw")); + assert!(!cmdline.contains("/dev/vda")); + } + + #[test] + fn test_root_disk_overrides_default() { + // When both has_root_disk=true and root_disk_device is set, + // the explicit device takes priority. 
+ let cmdline = build_kernel_cmdline( + None, + true, + &[], + Some("/dev/vdb"), + Some("ext4"), + None, + &[], + false, + ); + assert!(cmdline.contains("root=/dev/vdb")); + assert!(!cmdline.contains("root=/dev/vda")); + } + + #[test] + fn test_root_disk_device_without_fstype() { + let cmdline = + build_kernel_cmdline(None, false, &[], Some("/dev/vdb"), None, None, &[], false); + assert!(cmdline.contains("root=/dev/vdb")); + assert!(!cmdline.contains("rootfstype=")); + assert!(cmdline.contains("rw")); + } + + #[test] + fn test_init_path() { + let cmdline = build_kernel_cmdline( + None, + false, + &[], + None, + None, + Some("/boxlite/bin/boxlite-guest"), + &[], + false, + ); + assert!(cmdline.contains("init=/boxlite/bin/boxlite-guest")); + } + + #[test] + fn test_init_args_after_separator() { + let argv = vec![ + "--listen".to_string(), + "vsock://2695".to_string(), + "--notify".to_string(), + "vsock://2696".to_string(), + ]; + let cmdline = build_kernel_cmdline( + None, + false, + &[], + None, + None, + Some("/boxlite/bin/boxlite-guest"), + &argv, + false, + ); + assert!(cmdline.contains("init=/boxlite/bin/boxlite-guest")); + assert!(cmdline.ends_with("-- --listen vsock://2695 --notify vsock://2696")); + } + + #[test] + fn test_no_separator_when_argv_empty() { + let cmdline = + build_kernel_cmdline(None, false, &[], None, None, Some("/bin/init"), &[], false); + assert!(cmdline.contains("init=/bin/init")); + assert!(!cmdline.contains("--")); + } + + #[test] + fn test_full_lifecycle_cmdline() { + // Simulates the full box lifecycle cmdline: + // root=/dev/vdb rootfstype=ext4 rw init=/boxlite/bin/boxlite-guest + // virtio_mmio devices, then -- + let slots = vec![ + MmioSlot { + index: 0, + active: true, + }, + MmioSlot { + index: 1, + active: true, + }, + ]; + let argv = vec!["--listen".to_string(), "vsock://2695".to_string()]; + let cmdline = build_kernel_cmdline( + None, + true, + &slots, + Some("/dev/vdb"), + Some("ext4"), + Some("/boxlite/bin/boxlite-guest"), + 
&argv, + false, + ); + + // Verify ordering: base < root < init < mmio < argv + let root_pos = cmdline.find("root=/dev/vdb").unwrap(); + let init_pos = cmdline.find("init=/boxlite/bin/boxlite-guest").unwrap(); + let mmio_pos = cmdline.find("virtio_mmio").unwrap(); + let sep_pos = cmdline.find("-- --listen").unwrap(); + + assert!(root_pos < init_pos); + assert!(init_pos < mmio_pos); + assert!(mmio_pos < sep_pos); + assert!(!cmdline.contains("root=/dev/vda")); + } +} diff --git a/src/vmm/src/windows/context.rs b/src/vmm/src/windows/context.rs new file mode 100644 index 000000000..de0e4ac1f --- /dev/null +++ b/src/vmm/src/windows/context.rs @@ -0,0 +1,320 @@ +//! VM context — configuration state machine for building a VM. +//! +//! Mirrors libkrun's KrunContext pattern: create → configure → start. + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Mutex; + +use super::error::{Result, WkrunError}; +use super::types::VmState; + +/// Global context ID counter. +static NEXT_CTX_ID: AtomicU32 = AtomicU32::new(0); + +/// Global context map — maps context IDs to VM configurations. +/// Uses a Mutex for thread-safe access from the C API. +static CTX_MAP: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Disk format constants (matching libkrun). +pub const DISK_FORMAT_RAW: u32 = 0; +pub const DISK_FORMAT_QCOW2: u32 = 1; + +/// Configuration for a virtual machine. +pub struct VmContext { + /// Context ID. + pub id: u32, + /// Current state. + pub state: VmState, + /// Number of vCPUs. + pub num_vcpus: u8, + /// RAM size in MiB. + pub ram_mib: u32, + /// Root filesystem path. + pub root_path: Option, + /// Kernel image path (for direct boot). + pub kernel_path: Option, + /// Kernel command line. + pub kernel_cmdline: Option, + /// Initramfs path. + pub initramfs_path: Option, + /// Executable path for the guest init. 
+ pub exec_path: Option, + /// Arguments for the guest executable. + pub argv: Vec, + /// Environment variables for the guest. + pub envp: Vec, + /// Working directory in the guest. + pub workdir: Option, + /// Attached block devices. + pub disks: Vec, + /// Virtiofs/9p mounts. + pub fs_mounts: Vec, + /// Vsock port bridges. + pub vsock_ports: Vec, + /// Console output file path. + pub console_output: Option, + /// Resource limits to apply in the guest (format: "RESOURCE=CUR:MAX"). + pub rlimits: Vec, + /// Whether APIC emulation is enabled. + pub apic_emulation: bool, + /// Network device configuration. + pub net_config: Option, + /// Root disk device path override (e.g., "/dev/vdb"). + /// When set, the kernel cmdline uses this instead of the default "/dev/vda". + pub root_disk_device: Option, + /// Root disk filesystem type (e.g., "ext4"). + pub root_disk_fstype: Option, + /// Enable verbose serial console output for debugging. + /// + /// When true, the kernel cmdline includes `console=ttyS0` for full boot + /// logging. When false (default), quiet mode suppresses serial output and + /// i8042 probing for faster boot (~1-2s vs ~5s). + pub verbose: bool, +} + +/// Network device configuration. +pub struct NetConfig { + /// MAC address (6 bytes). If unset, auto-generated. + pub mac: [u8; 6], + /// Path to the userspace networking proxy socket. + pub socket_path: PathBuf, +} + +/// Block device configuration. +pub struct DiskConfig { + pub block_id: String, + pub path: PathBuf, + pub format: u32, + pub read_only: bool, +} + +/// Filesystem mount configuration (virtiofs or 9p). +pub struct FsMount { + pub tag: String, + pub host_path: PathBuf, +} + +/// Vsock port bridge configuration. +pub struct VsockPort { + pub port: u32, + pub host_path: PathBuf, + pub listen: bool, + /// Optional host TCP port override. When set, the vsock bridge listens on + /// this TCP port instead of the vsock port number. 
Enables multiple VMs + /// to use distinct host ports for the same guest vsock port. + pub host_tcp_port: Option, +}
+
+impl VmContext {
+    /// Build a context in the `Created` state with the built-in defaults:
+    /// 1 vCPU, 256 MiB of RAM, APIC emulation enabled, `verbose` off, and no
+    /// kernel, root filesystem, disks, mounts, vsock ports, or network device.
+    fn new(id: u32) -> Self {
+        Self {
+            id,
+            state: VmState::Created,
+            num_vcpus: 1,
+            ram_mib: 256,
+            root_path: None,
+            kernel_path: None,
+            kernel_cmdline: None,
+            initramfs_path: None,
+            exec_path: None,
+            argv: Vec::new(),
+            envp: Vec::new(),
+            workdir: None,
+            disks: Vec::new(),
+            fs_mounts: Vec::new(),
+            vsock_ports: Vec::new(),
+            console_output: None,
+            rlimits: Vec::new(),
+            apic_emulation: true,
+            net_config: None,
+            root_disk_device: None,
+            root_disk_fstype: None,
+            verbose: false,
+        }
+    }
+
+    /// Create a VmContext with default values for testing.
+    ///
+    /// The returned context has ID 0 and is NOT registered in the global map.
+    #[cfg(test)]
+    pub fn default_for_test() -> Self {
+        Self::new(0)
+    }
+}
+
+/// Create a new VM context. Returns the context ID (>= 0) on success.
+///
+/// IDs come from a process-wide atomic counter that is never rewound, so a
+/// freed ID is not handed out again until the counter wraps around.
+///
+/// # Errors
+///
+/// * `WkrunError::Config` if the global context map lock is poisoned.
+/// * `WkrunError::ContextExists` if the freshly allocated ID is still present
+///   in the map.
+pub fn create_ctx() -> Result<u32> {
+    let id = NEXT_CTX_ID.fetch_add(1, Ordering::Relaxed);
+
+    let mut map = CTX_MAP
+        .lock()
+        .map_err(|_| WkrunError::Config("context map lock poisoned".into()))?;
+
+    // Single hash lookup: claim the slot only if it is vacant.
+    match map.entry(id) {
+        std::collections::hash_map::Entry::Occupied(_) => Err(WkrunError::ContextExists(id)),
+        std::collections::hash_map::Entry::Vacant(slot) => {
+            slot.insert(VmContext::new(id));
+            Ok(id)
+        }
+    }
+}
+
+/// Free (destroy) a VM context. Returns Ok(()) on success.
+///
+/// # Errors
+///
+/// * `WkrunError::Config` if the global context map lock is poisoned.
+/// * `WkrunError::InvalidContext` if `ctx_id` is not in the map (never created,
+///   already freed, or already taken).
+pub fn free_ctx(ctx_id: u32) -> Result<()> {
+    let mut map = CTX_MAP
+        .lock()
+        .map_err(|_| WkrunError::Config("context map lock poisoned".into()))?;
+
+    map.remove(&ctx_id)
+        .ok_or(WkrunError::InvalidContext(ctx_id))?;
+
+    Ok(())
+}
+
+/// Execute a closure with mutable access to a VM context.
+///
+/// The global context map stays locked for as long as `f` runs, so `f` must
+/// not call back into any function of this module that takes the same lock.
+///
+/// # Errors
+///
+/// * `WkrunError::Config` if the global context map lock is poisoned.
+/// * `WkrunError::InvalidContext` if `ctx_id` is not in the map.
+/// * Whatever error `f` itself returns.
+pub fn with_ctx_mut<T, F>(ctx_id: u32, f: F) -> Result<T>
+where
+    F: FnOnce(&mut VmContext) -> Result<T>,
+{
+    let mut map = CTX_MAP
+        .lock()
+        .map_err(|_| WkrunError::Config("context map lock poisoned".into()))?;
+
+    let ctx = map
+        .get_mut(&ctx_id)
+        .ok_or(WkrunError::InvalidContext(ctx_id))?;
+
+    f(ctx)
+}
+
+/// Execute a closure with read access to a VM context.
+///
+/// The global context map stays locked for as long as `f` runs, so `f` must
+/// not call back into any function of this module that takes the same lock.
+///
+/// # Errors
+///
+/// * `WkrunError::Config` if the global context map lock is poisoned.
+/// * `WkrunError::InvalidContext` if `ctx_id` is not in the map.
+/// * Whatever error `f` itself returns.
+pub fn with_ctx<T, F>(ctx_id: u32, f: F) -> Result<T>
+where
+    F: FnOnce(&VmContext) -> Result<T>,
+{
+    let map = CTX_MAP
+        .lock()
+        .map_err(|_| WkrunError::Config("context map lock poisoned".into()))?;
+
+    let ctx = map.get(&ctx_id).ok_or(WkrunError::InvalidContext(ctx_id))?;
+
+    f(ctx)
+}
+
+/// Take (remove) a VM context from the global map.
+/// Used when starting the VM — the context is consumed.
+///
+/// After a successful call the ID is no longer known to the map, so a later
+/// `free_ctx` on the same ID fails with `InvalidContext`.
+///
+/// # Errors
+///
+/// * `WkrunError::Config` if the global context map lock is poisoned.
+/// * `WkrunError::InvalidContext` if `ctx_id` is not in the map.
+pub fn take_ctx(ctx_id: u32) -> Result<VmContext> {
+    let mut map = CTX_MAP
+        .lock()
+        .map_err(|_| WkrunError::Config("context map lock poisoned".into()))?;
+
+    map.remove(&ctx_id)
+        .ok_or(WkrunError::InvalidContext(ctx_id))
+}
+
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_and_free_ctx() { + let id = create_ctx().unwrap(); + assert!(free_ctx(id).is_ok()); + } + + #[test] + fn test_double_free_returns_error() { + let id = create_ctx().unwrap(); + assert!(free_ctx(id).is_ok()); + assert!(free_ctx(id).is_err()); + } + + #[test] + fn test_invalid_ctx_returns_error() { + assert!(free_ctx(u32::MAX).is_err()); + } + + #[test] + fn test_with_ctx_mut() { + let id = create_ctx().unwrap(); + + with_ctx_mut(id, |ctx| { + ctx.num_vcpus = 4; + ctx.ram_mib = 1024; + Ok(()) + }) + .unwrap(); + + with_ctx(id, |ctx| { + assert_eq!(ctx.num_vcpus, 4); + assert_eq!(ctx.ram_mib, 1024); + Ok(()) + }) + .unwrap(); + + free_ctx(id).unwrap(); + } + + #[test] + fn test_take_ctx() { + let id = create_ctx().unwrap(); + + with_ctx_mut(id, |ctx| { + ctx.ram_mib = 512; + Ok(()) + }) + .unwrap(); + + let ctx = take_ctx(id).unwrap(); + assert_eq!(ctx.ram_mib, 512); + + // After taking, the context should no longer exist + assert!(free_ctx(id).is_err()); + } + + #[test] + fn test_set_rlimits() { + let id = create_ctx().unwrap(); + + with_ctx_mut(id, |ctx| { + ctx.rlimits = vec![ + "RLIMIT_NOFILE=1024:4096".to_string(), + "RLIMIT_NPROC=512:1024".to_string(), + ]; + Ok(()) + }) + .unwrap(); + + with_ctx(id, |ctx| { + assert_eq!(ctx.rlimits.len(), 2); + assert_eq!(ctx.rlimits[0], 
"RLIMIT_NOFILE=1024:4096"); + assert_eq!(ctx.rlimits[1], "RLIMIT_NPROC=512:1024"); + Ok(()) + }) + .unwrap(); + + free_ctx(id).unwrap(); + } + + #[test] + fn test_context_defaults() { + let id = create_ctx().unwrap(); + + with_ctx(id, |ctx| { + assert_eq!(ctx.num_vcpus, 1); + assert_eq!(ctx.ram_mib, 256); + assert_eq!(ctx.state, VmState::Created); + assert!(ctx.root_path.is_none()); + assert!(ctx.kernel_path.is_none()); + assert!(ctx.disks.is_empty()); + assert!(ctx.fs_mounts.is_empty()); + assert!(ctx.vsock_ports.is_empty()); + assert!(ctx.rlimits.is_empty()); + Ok(()) + }) + .unwrap(); + + free_ctx(id).unwrap(); + } +} diff --git a/src/vmm/src/windows/devices/ioapic.rs b/src/vmm/src/windows/devices/ioapic.rs new file mode 100644 index 000000000..00fa4f5d3 --- /dev/null +++ b/src/vmm/src/windows/devices/ioapic.rs @@ -0,0 +1,507 @@ +//! I/O APIC (IOAPIC) emulation. +//! +//! Emulates a 24-pin IOAPIC with redirection table entries for routing +//! interrupts from devices to the Local APIC. +//! +//! MMIO interface at 0xFEC0_0000 (4KB region): +//! - Offset 0x00: IOREGSEL (write register index) +//! - Offset 0x10: IOWIN (read/write selected register) +//! +//! Registers: +//! - 0x00: IOAPIC ID +//! - 0x01: IOAPIC Version (24 entries, version 0x11) +//! - 0x10-0x3F: Redirection table entries (low/high 32 bits) + +/// Number of redirection table entries (pins). +const NUM_PINS: usize = 24; + +/// IOAPIC version register value. +/// Bits [7:0] = version (0x11 = 82093AA), bits [23:16] = max redirection entry (23). +const IOAPIC_VERSION: u32 = 0x0017_0011; + +/// A single redirection table entry. +/// +/// Each entry controls how an interrupt on the corresponding pin is delivered. +#[derive(Debug, Clone, Copy)] +struct RedirectionEntry { + /// IDT vector (0-255). + vector: u8, + /// Delivery mode: 0=Fixed, 2=SMI, 4=NMI, 5=INIT, 7=ExtINT. + delivery_mode: u8, + /// Destination mode: false=physical, true=logical. 
+ dest_mode: bool, + /// Pin polarity: false=active-high, true=active-low. + polarity: bool, + /// Trigger mode: false=edge, true=level. + trigger_mode: bool, + /// true = masked (interrupt suppressed). + mask: bool, + /// Level-triggered: set on delivery, cleared on EOI. + remote_irr: bool, + /// LAPIC destination ID. + dest: u8, +} + +impl Default for RedirectionEntry { + fn default() -> Self { + Self { + vector: 0, + delivery_mode: 0, + dest_mode: false, + polarity: false, + trigger_mode: false, + mask: true, // Masked by default + remote_irr: false, + dest: 0, + } + } +} + +impl RedirectionEntry { + /// Read the low 32 bits of the redirection entry. + fn read_low(&self) -> u32 { + let mut val = self.vector as u32; + val |= (self.delivery_mode as u32 & 0x7) << 8; + if self.dest_mode { + val |= 1 << 11; + } + if self.polarity { + val |= 1 << 13; + } + if self.remote_irr { + val |= 1 << 14; + } + if self.trigger_mode { + val |= 1 << 15; + } + if self.mask { + val |= 1 << 16; + } + val + } + + /// Read the high 32 bits (destination field in bits [31:24]). + fn read_high(&self) -> u32 { + (self.dest as u32) << 24 + } + + /// Write the low 32 bits. + fn write_low(&mut self, val: u32) { + self.vector = (val & 0xFF) as u8; + self.delivery_mode = ((val >> 8) & 0x7) as u8; + self.dest_mode = val & (1 << 11) != 0; + self.polarity = val & (1 << 13) != 0; + // remote_irr is read-only (bit 14). + self.trigger_mode = val & (1 << 15) != 0; + self.mask = val & (1 << 16) != 0; + } + + /// Write the high 32 bits. + fn write_high(&mut self, val: u32) { + self.dest = ((val >> 24) & 0xFF) as u8; + } +} + +/// 24-pin I/O APIC. +pub struct IoApic { + /// IOAPIC ID (bits [27:24] of register 0x00). + id: u8, + /// IOREGSEL: indirect register select. + reg_sel: u8, + /// 24 redirection table entries. + entries: [RedirectionEntry; NUM_PINS], + /// Pin assertion state (for level-triggered re-injection). 
+    pin_state: u32,
+}
+
+impl Default for IoApic {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl IoApic {
+    /// Create a new IOAPIC with default state (all pins masked).
+    pub fn new() -> Self {
+        Self {
+            id: 0,
+            reg_sel: 0,
+            entries: [RedirectionEntry::default(); NUM_PINS],
+            pin_state: 0,
+        }
+    }
+
+    /// Process an IRQ signal. Returns `(vector, dest_apic_id)` if the interrupt
+    /// is deliverable, or None if masked/blocked.
+    ///
+    /// - `irq` values at or above the pin count are ignored and return None.
+    /// - `level == false` (deassertion) only clears the recorded pin state.
+    /// - `level == true` records the pin as asserted even when the entry is
+    ///   masked, then:
+    ///   - Edge-triggered: deliver if not masked. There is no edge detection;
+    ///     every assertion of an unmasked edge pin delivers.
+    ///   - Level-triggered: deliver if not masked AND remote_irr not set;
+    ///     delivery sets remote_irr until the matching EOI.
+    pub fn service_irq(&mut self, irq: u8, level: bool) -> Option<(u8, u8)> {
+        let pin = usize::from(irq);
+        if pin >= NUM_PINS {
+            return None;
+        }
+
+        let pin_bit = 1u32 << irq;
+        if !level {
+            self.pin_state &= !pin_bit;
+            return None; // Deassertion doesn't deliver.
+        }
+        self.pin_state |= pin_bit;
+
+        let entry = &mut self.entries[pin];
+        if entry.mask {
+            return None;
+        }
+
+        if entry.trigger_mode {
+            // Level-triggered: only deliver if remote_irr is not set.
+            if entry.remote_irr {
+                return None;
+            }
+            entry.remote_irr = true;
+        }
+        // Edge-triggered: always deliver (if not masked).
+
+        Some((entry.vector, entry.dest))
+    }
+
+    /// Handle End-of-Interrupt for a given vector.
+    ///
+    /// Clears remote_irr on EVERY level-triggered entry that is programmed with
+    /// `vector` and currently has remote_irr set, so that no pin sharing the
+    /// vector stays blocked in `service_irq` after the EOI.
+    ///
+    /// Returns the lowest such pin that is still asserted (needs re-injection),
+    /// or None if no matching pin is still asserted.
+    pub fn end_of_interrupt(&mut self, vector: u8) -> Option<u8> {
+        let mut reinject = None;
+        for (pin, entry) in self.entries.iter_mut().enumerate() {
+            let awaiting_eoi = entry.vector == vector && entry.trigger_mode && entry.remote_irr;
+            if !awaiting_eoi {
+                continue;
+            }
+            entry.remote_irr = false;
+            // Remember only the first still-asserted pin, but keep scanning so
+            // the remaining entries with this vector are released as well.
+            if reinject.is_none() && self.pin_state & (1 << pin) != 0 {
+                // `pin` is below NUM_PINS (24), so the cast cannot truncate.
+                reinject = Some(pin as u8);
+            }
+        }
+        reinject
+    }
+
+    /// Read from the IOAPIC MMIO region.
+    ///
+    /// Only offsets 0x00 (IOREGSEL) and 0x10 (IOWIN) are valid.
+    pub fn read_mmio(&self, offset: u64) -> u32 {
+        const IOREGSEL: u64 = 0x00;
+        const IOWIN: u64 = 0x10;
+
+        if offset == IOREGSEL {
+            u32::from(self.reg_sel)
+        } else if offset == IOWIN {
+            self.read_register(self.reg_sel)
+        } else {
+            // Any other offset in the window reads as zero.
+            0
+        }
+    }
+
+    /// Write to the IOAPIC MMIO region.
+    ///
+    /// Offset 0x00 latches the low 8 bits of `value` as the register index;
+    /// offset 0x10 writes `value` to the currently selected register. Writes
+    /// to any other offset are dropped.
+    pub fn write_mmio(&mut self, offset: u64, value: u32) {
+        const IOREGSEL: u64 = 0x00;
+        const IOWIN: u64 = 0x10;
+
+        if offset == IOREGSEL {
+            self.reg_sel = value as u8;
+        } else if offset == IOWIN {
+            self.write_register(self.reg_sel, value);
+        }
+    }
+
+    /// Read an indirect register by index.
+    ///
+    /// Even indices in 0x10..=0x3F address the low word of a redirection
+    /// entry, odd indices the high word. Unknown indices read as zero.
+    fn read_register(&self, reg: u8) -> u32 {
+        match reg {
+            // IOAPIC ID lives in the top byte.
+            0x00 => u32::from(self.id) << 24,
+            // Version
+            0x01 => IOAPIC_VERSION,
+            // Arbitration ID (not used)
+            0x02 => 0,
+            0x10..=0x3F => {
+                let pin = usize::from((reg - 0x10) / 2);
+                let wants_low = reg & 1 == 0;
+                self.entries.get(pin).map_or(0, |entry| {
+                    if wants_low {
+                        entry.read_low()
+                    } else {
+                        entry.read_high()
+                    }
+                })
+            }
+            _ => 0,
+        }
+    }
+
+    /// Check if any redirection table entry is unmasked (active).
+    pub fn has_unmasked_entries(&self) -> bool {
+        !self.entries.iter().all(|entry| entry.mask)
+    }
+
+    /// Write an indirect register by index.
+    ///
+    /// Register 0x00 keeps only bits [27:24] of `value` as the IOAPIC ID.
+    /// Indices 0x10..=0x3F update the low (even) or high (odd) word of the
+    /// addressed redirection entry. Every other index is ignored.
+    fn write_register(&mut self, reg: u8, value: u32) {
+        match reg {
+            0x00 => self.id = ((value >> 24) & 0x0F) as u8,
+            0x10..=0x3F => {
+                let pin = usize::from((reg - 0x10) / 2);
+                let wants_low = reg & 1 == 0;
+                if let Some(entry) = self.entries.get_mut(pin) {
+                    if wants_low {
+                        entry.write_low(value);
+                    } else {
+                        entry.write_high(value);
+                    }
+                }
+            }
+            _ => {} // Read-only or reserved registers.
+        }
+    }
+}
+
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ioapic_initial_state() { + let ioapic = IoApic::new(); + assert_eq!(ioapic.id, 0); + assert_eq!(ioapic.reg_sel, 0); + // All entries should be masked. + for entry in &ioapic.entries { + assert!(entry.mask); + assert_eq!(entry.vector, 0); + } + assert!(!ioapic.has_unmasked_entries()); + } + + #[test] + fn test_ioapic_has_unmasked_entries() { + let mut ioapic = IoApic::new(); + assert!(!ioapic.has_unmasked_entries()); + + // Unmask pin 2 with vector 0x22. 
+ ioapic.write_mmio(0x00, 0x14); // Select register 0x14 (pin 2 low) + ioapic.write_mmio(0x10, 0x22); // vector=0x22, mask bit 16 = 0 (unmasked) + assert!(ioapic.has_unmasked_entries()); + } + + #[test] + fn test_ioapic_version_register() { + let ioapic = IoApic::new(); + // Select version register. + let version = ioapic.read_register(0x01); + assert_eq!(version & 0xFF, 0x11, "version should be 0x11"); + assert_eq!((version >> 16) & 0xFF, 23, "max redir entry should be 23"); + } + + #[test] + fn test_ioapic_id_read_write() { + let mut ioapic = IoApic::new(); + ioapic.write_register(0x00, 0x0A00_0000); // Set ID = 0x0A + assert_eq!(ioapic.read_register(0x00), 0x0A00_0000); + assert_eq!(ioapic.id, 0x0A); + } + + #[test] + fn test_ioapic_redir_entry_read_write() { + let mut ioapic = IoApic::new(); + + // Write low 32 bits of entry 0 (register 0x10): + // vector=0x30, delivery_mode=0 (Fixed), level-triggered, unmasked + let low: u32 = 0x30 | (1 << 15); // vector=0x30, trigger=level, mask=0 + ioapic.write_register(0x10, low); + + // Write high 32 bits of entry 0 (register 0x11): + // destination = LAPIC 0 + ioapic.write_register(0x11, 0x00 << 24); + + let read_low = ioapic.read_register(0x10); + assert_eq!(read_low & 0xFF, 0x30, "vector"); + assert!(read_low & (1 << 15) != 0, "trigger mode should be level"); + assert!(read_low & (1 << 16) == 0, "should be unmasked"); + + let read_high = ioapic.read_register(0x11); + assert_eq!((read_high >> 24) & 0xFF, 0, "dest should be 0"); + } + + #[test] + fn test_ioapic_masked_irq_not_delivered() { + let mut ioapic = IoApic::new(); + // Entry 0 is masked by default. + assert_eq!(ioapic.service_irq(0, true), None); + } + + #[test] + fn test_ioapic_edge_triggered_delivery() { + let mut ioapic = IoApic::new(); + + // Configure pin 5: edge-triggered, vector 0x25, dest 0, unmasked. 
+ ioapic.entries[5].vector = 0x25; + ioapic.entries[5].mask = false; + ioapic.entries[5].trigger_mode = false; // Edge + + let result = ioapic.service_irq(5, true); + assert_eq!(result, Some((0x25, 0))); + } + + #[test] + fn test_ioapic_level_triggered_delivery() { + let mut ioapic = IoApic::new(); + + // Configure pin 3: level-triggered, vector 0x33, dest 0, unmasked. + ioapic.entries[3].vector = 0x33; + ioapic.entries[3].mask = false; + ioapic.entries[3].trigger_mode = true; // Level + + let result = ioapic.service_irq(3, true); + assert_eq!(result, Some((0x33, 0))); + assert!(ioapic.entries[3].remote_irr, "remote_irr should be set"); + } + + #[test] + fn test_ioapic_level_triggered_blocked_by_remote_irr() { + let mut ioapic = IoApic::new(); + + // Configure pin 3: level-triggered, vector 0x33, unmasked. + ioapic.entries[3].vector = 0x33; + ioapic.entries[3].mask = false; + ioapic.entries[3].trigger_mode = true; + + // First delivery sets remote_irr. + assert_eq!(ioapic.service_irq(3, true), Some((0x33, 0))); + + // Second delivery blocked by remote_irr. + assert_eq!(ioapic.service_irq(3, true), None); + } + + #[test] + fn test_ioapic_eoi_clears_remote_irr() { + let mut ioapic = IoApic::new(); + + ioapic.entries[3].vector = 0x33; + ioapic.entries[3].mask = false; + ioapic.entries[3].trigger_mode = true; + + ioapic.service_irq(3, true); + assert!(ioapic.entries[3].remote_irr); + + // EOI should clear remote_irr and return the pin for re-injection. + let reinject_pin = ioapic.end_of_interrupt(0x33); + assert!(!ioapic.entries[3].remote_irr); + // Pin is still asserted, so re-injection needed on pin 3. + assert_eq!(reinject_pin, Some(3)); + } + + #[test] + fn test_ioapic_eoi_no_reinjection_when_deasserted() { + let mut ioapic = IoApic::new(); + + ioapic.entries[3].vector = 0x33; + ioapic.entries[3].mask = false; + ioapic.entries[3].trigger_mode = true; + + ioapic.service_irq(3, true); + // Deassert the pin. 
+ ioapic.service_irq(3, false); + + let reinject_pin = ioapic.end_of_interrupt(0x33); + assert_eq!(reinject_pin, None, "no reinjection when pin is deasserted"); + } + + #[test] + fn test_ioapic_deassertion_does_not_deliver() { + let mut ioapic = IoApic::new(); + + ioapic.entries[5].vector = 0x25; + ioapic.entries[5].mask = false; + + // Deassertion (level=false) should not deliver. + assert_eq!(ioapic.service_irq(5, false), None); + } + + #[test] + fn test_ioapic_out_of_range_irq() { + let mut ioapic = IoApic::new(); + assert_eq!(ioapic.service_irq(24, true), None); + assert_eq!(ioapic.service_irq(255, true), None); + } + + #[test] + fn test_ioapic_mmio_regsel() { + let mut ioapic = IoApic::new(); + + // Write IOREGSEL. + ioapic.write_mmio(0x00, 0x01); + assert_eq!(ioapic.reg_sel, 0x01); + + // Read IOREGSEL. + assert_eq!(ioapic.read_mmio(0x00), 0x01); + } + + #[test] + fn test_ioapic_mmio_iowin_version() { + let mut ioapic = IoApic::new(); + + // Select version register. + ioapic.write_mmio(0x00, 0x01); + let version = ioapic.read_mmio(0x10); + assert_eq!(version & 0xFF, 0x11); + } + + #[test] + fn test_ioapic_mmio_invalid_offset() { + let mut ioapic = IoApic::new(); + // Invalid offsets should return 0 / be no-ops. + assert_eq!(ioapic.read_mmio(0x04), 0); + ioapic.write_mmio(0x04, 0xDEAD); + } + + #[test] + fn test_ioapic_redir_entry_remote_irr_readonly() { + let mut ioapic = IoApic::new(); + + // Set remote_irr manually. + ioapic.entries[0].remote_irr = true; + + // Write low word without remote_irr bit — it should NOT clear remote_irr. + let low = 0x30u32; // vector=0x30, no remote_irr bit set + ioapic.write_register(0x10, low); + + // remote_irr is read-only in the write path. + assert!(ioapic.entries[0].remote_irr); + } + + #[test] + fn test_ioapic_multiple_pins_independent() { + let mut ioapic = IoApic::new(); + + // Configure two different pins. 
+ ioapic.entries[1].vector = 0x21; + ioapic.entries[1].mask = false; + ioapic.entries[2].vector = 0x22; + ioapic.entries[2].mask = false; + + assert_eq!(ioapic.service_irq(1, true), Some((0x21, 0))); + assert_eq!(ioapic.service_irq(2, true), Some((0x22, 0))); + } + + #[test] + fn test_ioapic_out_of_range_register() { + let ioapic = IoApic::new(); + // Registers beyond 0x3F should return 0. + assert_eq!(ioapic.read_register(0x40), 0); + assert_eq!(ioapic.read_register(0xFF), 0); + } + + #[test] + fn test_ioapic_service_irq_returns_destination() { + let mut ioapic = IoApic::new(); + + // Configure pin 4: vector 0x24, dest APIC ID = 1, unmasked. + ioapic.entries[4].vector = 0x24; + ioapic.entries[4].dest = 1; + ioapic.entries[4].mask = false; + + let result = ioapic.service_irq(4, true); + assert_eq!(result, Some((0x24, 1))); + } + + #[test] + fn test_ioapic_pin_beyond_24_in_redir() { + let ioapic = IoApic::new(); + // Register 0x10 + 24*2 = 0x40, which is pin 24 (out of range). + assert_eq!(ioapic.read_register(0x40), 0); + } +} diff --git a/src/vmm/src/windows/devices/irq_chip.rs b/src/vmm/src/windows/devices/irq_chip.rs new file mode 100644 index 000000000..b12e9c9e0 --- /dev/null +++ b/src/vmm/src/windows/devices/irq_chip.rs @@ -0,0 +1,714 @@ +//! IrqChip — coordinator wiring PIC + IOAPIC + LAPIC(s) together. +//! +//! Manages the interrupt routing between legacy PIC (for early boot before +//! APIC is enabled) and the IOAPIC + LAPIC path (after guest enables APIC). +//! +//! Supports multiple LAPICs for multi-vCPU configurations. Each vCPU has its +//! own LAPIC, indexed by vCPU ID. Device interrupts from the IOAPIC are routed +//! to the target LAPIC based on the redirection entry destination field. +//! +//! The APIC mode is auto-detected: when the guest writes to the LAPIC SVR +//! register with the enable bit set, the IrqChip switches to APIC mode. 
+
+use std::sync::{Arc, Mutex};
+use std::time::Instant;
+
+use super::super::memory::{IOAPIC_MMIO_BASE, IOAPIC_MMIO_SIZE, LAPIC_MMIO_BASE, LAPIC_MMIO_SIZE};
+use super::ioapic::IoApic;
+use super::lapic::{IpiAction, LocalApic, SharedApicState};
+use super::pic::Pic;
+
+/// Result of an IrqChip MMIO write operation.
+#[derive(Debug)]
+pub struct IrqChipWriteResult {
+    /// Whether the address was handled by the IrqChip.
+    pub handled: bool,
+    /// IPI action to dispatch (from LAPIC ICR write).
+    pub ipi_action: IpiAction,
+}
+
+impl Default for IrqChipWriteResult {
+    /// The "not handled" result: `handled == false` and no IPI action.
+    fn default() -> Self {
+        Self {
+            handled: false,
+            ipi_action: IpiAction::None,
+        }
+    }
+}
+
+/// Coordinated interrupt controller combining PIC, IOAPIC, and per-vCPU LAPICs.
+pub struct IrqChip {
+    /// Legacy PIC (for early boot before APIC is enabled).
+    pub pic: Pic,
+    /// I/O APIC for routing device interrupts to the LAPICs.
+    ioapic: IoApic,
+    /// Per-vCPU Local APICs (indexed by vCPU ID).
+    ///
+    /// Each LAPIC is wrapped in its own `Arc<Mutex<LocalApic>>` to allow
+    /// per-vCPU locking.
+    /// This eliminates cross-vCPU contention during LAPIC MMIO reads (esp. timer
+    /// CCR at 0x390), which is critical for 4+ vCPU support — without this, SMP
+    /// timer calibration causes BSP starvation on tick_and_poll().
+    lapics: Vec<Arc<Mutex<LocalApic>>>,
+    /// Per-vCPU shared APIC state for lock-free cross-vCPU interrupt delivery.
+    ///
+    /// Source vCPUs atomically OR vector bits into the target's SharedApicState.
+    /// The owning vCPU pulls these into its local IRR via `pull_irr()`.
+    shared_states: Vec<Arc<SharedApicState>>,
+    /// false = PIC mode (early boot), true = APIC mode.
+    apic_mode: bool,
+}
+
+impl Default for IrqChip {
+    /// A single-vCPU chip in PIC mode.
+    fn default() -> Self {
+        Self::new(1)
+    }
+}
+
+impl IrqChip {
+    /// Create a new IrqChip in PIC mode (legacy boot) with N LAPICs.
+    ///
+    /// One LAPIC (with APIC ID equal to the vCPU index) and one shared-state
+    /// slot are created per vCPU.
+    pub fn new(num_vcpus: u8) -> Self {
+        let lapics = (0..num_vcpus)
+            .map(|id| Arc::new(Mutex::new(LocalApic::new_with_id(id))))
+            .collect();
+        let shared_states = (0..num_vcpus)
+            .map(|_| Arc::new(SharedApicState::new()))
+            .collect();
+        Self {
+            pic: Pic::new(),
+            ioapic: IoApic::new(),
+            lapics,
+            shared_states,
+            apic_mode: false,
+        }
+    }
+
+    /// Get a clone of the `Arc<Mutex<LocalApic>>` for a specific vCPU.
+    ///
+    /// Used by the runner to acquire per-vCPU LAPIC refs that can be locked
+    /// independently of the DeviceManager lock (fast path for MMIO reads).
+    ///
+    /// # Panics
+    ///
+    /// Panics if `vcpu_id` is not a valid vCPU index.
+    pub fn get_lapic_ref(&self, vcpu_id: u32) -> Arc<Mutex<LocalApic>> {
+        Arc::clone(&self.lapics[vcpu_id as usize])
+    }
+
+    /// Get a clone of the `Arc<SharedApicState>` for a specific vCPU.
+    ///
+    /// Used by the runner for lock-free cross-vCPU interrupt delivery.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `vcpu_id` is not a valid vCPU index.
+    pub fn get_shared_state(&self, vcpu_id: u32) -> Arc<SharedApicState> {
+        Arc::clone(&self.shared_states[vcpu_id as usize])
+    }
+
+    /// Number of vCPUs (LAPICs).
+    pub fn num_vcpus(&self) -> u8 {
+        // `lapics` is built from a `u8` count in `new`, so this cannot truncate.
+        self.lapics.len() as u8
+    }
+
+    /// Whether the chip is in APIC mode (vs legacy PIC mode).
+    pub fn apic_mode(&self) -> bool {
+        self.apic_mode
+    }
+
+    /// Raise an interrupt on the given ISA IRQ line.
+    ///
+    /// Routes to IOAPIC (if APIC mode) or PIC (legacy mode).
+    /// In APIC mode, applies the standard x86 IRQ-to-GSI remapping:
+    /// ISA IRQ 0 (PIT timer) → IOAPIC pin 2 (GSI 2), matching the
+    /// Interrupt Source Override entry in the MADT.
+    ///
+    /// A destination ID beyond the last LAPIC is clamped to the last LAPIC.
+    /// On a chip created with zero vCPUs the interrupt is dropped instead of
+    /// panicking.
+    pub fn raise_irq(&mut self, irq: u8) {
+        if !self.apic_mode {
+            self.pic.raise_irq(irq);
+            return;
+        }
+
+        // Remap ISA IRQ to IOAPIC pin (GSI).
+        // Standard x86: PIT timer (IRQ 0) routes to IOAPIC pin 2.
+        let gsi = if irq == 0 { 2 } else { irq };
+        let Some((vector, dest)) = self.ioapic.service_irq(gsi, true) else {
+            return;
+        };
+        // `len() - 1` would underflow when there are no LAPICs at all.
+        let Some(last) = self.shared_states.len().checked_sub(1) else {
+            return;
+        };
+        let target = usize::from(dest).min(last);
+        // Lock-free: atomic OR into shared state instead of locking LAPIC.
+        self.shared_states[target].request_interrupt(vector);
+    }
+
+    /// Get the highest-priority injectable vector for a specific vCPU.
+    ///
+    /// Checks LAPIC (APIC mode) or PIC (legacy mode, only for BSP / vCPU 0).
+    ///
+    /// In PIC mode the returned value is only a "something is pending"
+    /// sentinel (`Some(0)`), not a real vector; use `acknowledge()` to obtain
+    /// the actual vector.
+    ///
+    /// # Panics
+    ///
+    /// In APIC mode, panics if `vcpu_id` is not a valid vCPU index or the
+    /// LAPIC mutex is poisoned.
+    pub fn get_injectable_vector(&self, vcpu_id: u8) -> Option<u8> {
+        if self.apic_mode {
+            self.lapics[vcpu_id as usize]
+                .lock()
+                .unwrap()
+                .get_highest_injectable()
+        } else if vcpu_id == 0 {
+            // PIC has pending, but we can only peek here — acknowledging would
+            // consume the interrupt. Report the sentinel instead.
+            self.pic.has_pending().then_some(0)
+        } else {
+            None // APs don't get PIC interrupts.
+        }
+    }
+
+    /// Check if there are any pending interrupts for a specific vCPU.
+    ///
+    /// Equivalent to `get_injectable_vector(vcpu_id).is_some()`, including its
+    /// panic conditions.
+    pub fn has_pending(&self, vcpu_id: u8) -> bool {
+        self.get_injectable_vector(vcpu_id).is_some()
+    }
+
+    /// Acknowledge the highest-priority interrupt for a specific vCPU.
+    ///
+    /// In PIC mode (vCPU 0 only): acknowledges from PIC and returns the vector.
+    /// In APIC mode: returns the highest injectable from the vCPU's LAPIC.
+    ///
+    /// # Panics
+    ///
+    /// In APIC mode, panics if `vcpu_id` is not a valid vCPU index or the
+    /// LAPIC mutex is poisoned.
+    pub fn acknowledge(&mut self, vcpu_id: u8) -> Option<u8> {
+        if self.apic_mode {
+            self.lapics[vcpu_id as usize]
+                .lock()
+                .unwrap()
+                .get_highest_injectable()
+        } else if vcpu_id == 0 {
+            self.pic.acknowledge()
+        } else {
+            None
+        }
+    }
+
+    /// Called after the vector has been injected into the vCPU.
+    ///
+    /// In APIC mode: moves the vector from IRR to ISR in the vCPU's LAPIC.
+    /// In PIC mode: no-op (PIC acknowledge already moved to ISR).
+    ///
+    /// # Panics
+    ///
+    /// In APIC mode, panics if `vcpu_id` is not a valid vCPU index or the
+    /// LAPIC mutex is poisoned.
+    pub fn notify_injected(&mut self, vcpu_id: u8, vector: u8) {
+        if self.apic_mode {
+            self.lapics[vcpu_id as usize]
+                .lock()
+                .unwrap()
+                .start_of_interrupt(vector);
+        }
+    }
+
+    /// Handle an EOI from a specific vCPU's LAPIC.
+    ///
+    /// Propagates EOI from LAPIC to IOAPIC for level-triggered interrupt
+    /// completion. May trigger re-injection if the pin is still asserted.
+    ///
+    /// `_vcpu_id` is accepted for routing context only; the IOAPIC matches the
+    /// EOI purely by vector. On a chip with zero vCPUs the re-delivery is
+    /// dropped instead of panicking.
+    fn handle_lapic_eoi(&mut self, _vcpu_id: u8, vector: u8) {
+        let Some(pin) = self.ioapic.end_of_interrupt(vector) else {
+            return;
+        };
+        // Pin still asserted — re-deliver using the correct IOAPIC pin.
+        let Some((new_vector, dest)) = self.ioapic.service_irq(pin, true) else {
+            return;
+        };
+        // `len() - 1` would underflow when there are no LAPICs at all.
+        let Some(last) = self.shared_states.len().checked_sub(1) else {
+            return;
+        };
+        let target = usize::from(dest).min(last);
+        // Lock-free: atomic OR into shared state.
+        self.shared_states[target].request_interrupt(new_vector);
+    }
+
+    /// Tick the LAPIC timer for a specific vCPU. Returns the timer vector if it fired.
+    ///
+    /// Always returns None in PIC mode. When the timer fires, the vector is
+    /// also accepted into that vCPU's LAPIC before being returned.
+    ///
+    /// # Panics
+    ///
+    /// In APIC mode, panics if `vcpu_id` is not a valid vCPU index or the
+    /// LAPIC mutex is poisoned.
+    pub fn tick_timer(&mut self, vcpu_id: u8, now: Instant) -> Option<u8> {
+        if !self.apic_mode {
+            return None;
+        }
+        let mut lapic = self.lapics[vcpu_id as usize].lock().unwrap();
+        let vector = lapic.tick_timer(now)?;
+        lapic.accept_interrupt(vector);
+        Some(vector)
+    }
+
+    /// Handle an MMIO read to an IOAPIC or LAPIC address.
+    ///
+    /// Returns Some(value) if the address was handled, None otherwise.
+    /// LAPIC reads are dispatched to the requesting vCPU's LAPIC.
+    /// The access size is currently ignored.
+    ///
+    /// # Panics
+    ///
+    /// For LAPIC addresses, panics if `vcpu_id` is not a valid vCPU index or
+    /// the LAPIC mutex is poisoned.
+    pub fn handle_mmio_read(&self, vcpu_id: u8, addr: u64, _size: u8) -> Option<u32> {
+        if (IOAPIC_MMIO_BASE..IOAPIC_MMIO_BASE + IOAPIC_MMIO_SIZE).contains(&addr) {
+            let offset = addr - IOAPIC_MMIO_BASE;
+            Some(self.ioapic.read_mmio(offset))
+        } else if (LAPIC_MMIO_BASE..LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE).contains(&addr) {
+            let offset = addr - LAPIC_MMIO_BASE;
+            Some(
+                self.lapics[vcpu_id as usize]
+                    .lock()
+                    .unwrap()
+                    .read_mmio(offset),
+            )
+        } else {
+            None
+        }
+    }
+
+    /// Handle an MMIO write to an IOAPIC or LAPIC address.
+    ///
+    /// Returns an `IrqChipWriteResult` indicating whether the address was handled
+    /// and any IPI action from an ICR write.
+    pub fn handle_mmio_write(
+        &mut self,
+        vcpu_id: u8,
+        addr: u64,
+        _size: u8,
+        data: u32,
+    ) -> IrqChipWriteResult {
+        if (IOAPIC_MMIO_BASE..IOAPIC_MMIO_BASE + IOAPIC_MMIO_SIZE).contains(&addr) {
+            self.ioapic.write_mmio(addr - IOAPIC_MMIO_BASE, data);
+            // The write may have unmasked a redirection entry, which is one
+            // half of the PIC → APIC switch condition.
+            self.check_apic_transition();
+            return IrqChipWriteResult {
+                handled: true,
+                ipi_action: IpiAction::None,
+            };
+        }
+
+        if !(LAPIC_MMIO_BASE..LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE).contains(&addr) {
+            // Neither window claims this address.
+            return IrqChipWriteResult::default();
+        }
+
+        // The LAPIC guard is a temporary of this statement, so the lock is
+        // released again before the transition check below.
+        let lapic_write = self.lapics[vcpu_id as usize]
+            .lock()
+            .unwrap()
+            .write_mmio(addr - LAPIC_MMIO_BASE, data);
+
+        // The write may have set the SVR enable bit, the other half of the
+        // switch condition.
+        self.check_apic_transition();
+
+        // Forward a completed interrupt to the IOAPIC.
+        if let Some(vector) = lapic_write.eoi_vector {
+            self.handle_lapic_eoi(vcpu_id, vector);
+        }
+
+        IrqChipWriteResult {
+            handled: true,
+            ipi_action: lapic_write.ipi_action,
+        }
+    }
+
+    /// Post an IPI vector to the LAPIC whose APIC ID is `target_apic_id`.
+    ///
+    /// The runner calls this when an ICR write yields an IPI action aimed at
+    /// another LAPIC (SendInterrupt variant only — INIT and SIPI are handled
+    /// by the runner's AP startup logic). A target ID with no matching LAPIC
+    /// is silently ignored.
+    pub fn deliver_ipi_interrupt(&mut self, target_apic_id: u8, vector: u8) {
+        // Lock-free: atomic OR into the shared state rather than locking the LAPIC.
+        if let Some(state) = self.shared_states.get(usize::from(target_apic_id)) {
+            state.request_interrupt(vector);
+        }
+    }
+
+    /// Check if conditions are met to switch from PIC to APIC mode.
+    ///
+    /// The transition requires BOTH:
+    /// 1. Any LAPIC is software-enabled (SVR bit 8 set by guest)
+    /// 2. IOAPIC has at least one unmasked redirection entry
+    ///
+    /// This prevents a gap where the kernel has enabled the LAPIC but hasn't
+    /// yet programmed the IOAPIC entries, which would silently drop interrupts
+    /// (all IOAPIC entries start masked).
+    fn check_apic_transition(&mut self) {
+        // The switch is one-way: once in APIC mode there is nothing to do.
+        if self.apic_mode {
+            return;
+        }
+
+        let lapic_enabled = self
+            .lapics
+            .iter()
+            .any(|lapic| lapic.lock().unwrap().is_enabled());
+        if !lapic_enabled || !self.ioapic.has_unmasked_entries() {
+            return;
+        }
+
+        log::info!("APIC mode enabled — LAPIC active + IOAPIC has unmasked entries");
+        self.apic_mode = true;
+    }
+
+    /// Diagnostic snapshot of the master PIC, forwarded unchanged from the PIC.
+    pub fn pic_master_state(&self) -> (u8, u8, u8, u8) {
+        let master = self.pic.master_state();
+        master
+    }
+
+    /// Diagnostic snapshot of the slave PIC, forwarded unchanged from the PIC.
+    pub fn pic_slave_state(&self) -> (u8, u8, u8, u8) {
+        let slave = self.pic.slave_state();
+        slave
+    }
+}
+
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_irq_chip_starts_in_pic_mode() { + let chip = IrqChip::new(1); + assert!(!chip.apic_mode()); + } + + #[test] + fn test_irq_chip_multi_vcpu_creates_lapics() { + let chip = IrqChip::new(4); + assert_eq!(chip.num_vcpus(), 4); + assert_eq!(chip.lapics[0].lock().unwrap().id(), 0); + assert_eq!(chip.lapics[1].lock().unwrap().id(), 1); + assert_eq!(chip.lapics[2].lock().unwrap().id(), 2); + assert_eq!(chip.lapics[3].lock().unwrap().id(), 3); + } + + #[test] + fn test_irq_chip_pic_mode_raise_irq() { + let mut chip = IrqChip::new(1); + + // Program PIC for testing. + chip.pic.write_port(0x20, 0x11); // ICW1 + chip.pic.write_port(0x21, 0x20); // ICW2: vector base 0x20 + chip.pic.write_port(0x21, 0x04); // ICW3 + chip.pic.write_port(0x21, 0x01); // ICW4 + chip.pic.write_port(0x21, 0x00); // IMR: unmask all + + chip.raise_irq(0); + assert!(chip.has_pending(0)); + + let vector = chip.acknowledge(0); + assert_eq!(vector, Some(0x20)); + } + + #[test] + fn test_irq_chip_pic_mode_only_bsp() { + let mut chip = IrqChip::new(2); + + // Program PIC. + chip.pic.write_port(0x20, 0x11); + chip.pic.write_port(0x21, 0x20); + chip.pic.write_port(0x21, 0x04); + chip.pic.write_port(0x21, 0x01); + chip.pic.write_port(0x21, 0x00); + + chip.raise_irq(0); + + // BSP (vCPU 0) sees the interrupt. 
+        assert!(chip.has_pending(0));
+        // AP (vCPU 1) does NOT see PIC interrupts.
+        assert!(!chip.has_pending(1));
+    }
+
+    #[test]
+    fn test_irq_chip_apic_mode_switch_requires_ioapic_entries() {
+        let mut chip = IrqChip::new(1);
+        assert!(!chip.apic_mode());
+
+        // Write to LAPIC SVR with enable bit — NOT enough alone.
+        let svr_addr = LAPIC_MMIO_BASE + 0x0F0;
+        chip.handle_mmio_write(0, svr_addr, 4, 0x1FF);
+        assert!(
+            !chip.apic_mode(),
+            "APIC mode must NOT activate on SVR alone"
+        );
+
+        // Unmask an IOAPIC entry (pin 2, vector 0x22) — NOW transition triggers.
+        // The IOAPIC is programmed indirectly: select a register at offset 0x00,
+        // then write its value through the window at offset 0x10.
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE, 4, 0x14); // Select reg 0x14 (pin 2 low)
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE + 0x10, 4, 0x22); // vector=0x22, unmasked
+
+        assert!(
+            chip.apic_mode(),
+            "APIC mode should activate when LAPIC enabled + IOAPIC unmasked"
+        );
+    }
+
+    #[test]
+    fn test_irq_chip_apic_mode_raise_irq() {
+        let mut chip = IrqChip::new(1);
+
+        // Enable LAPIC SVR.
+        chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x0F0, 4, 0x1FF);
+
+        // Configure IOAPIC pin 5: vector 0x25, unmasked, edge-triggered.
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE, 4, 0x1A); // Select register 0x1A (pin 5 low)
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE + 0x10, 4, 0x25); // vector=0x25, unmasked
+        assert!(chip.apic_mode());
+
+        chip.raise_irq(5);
+        // pull_irr: merge shared state into local IRR (lock-free delivery path).
+        chip.lapics[0]
+            .lock()
+            .unwrap()
+            .pull_irr(&chip.shared_states[0]);
+        assert!(chip.has_pending(0));
+
+        let vector = chip.acknowledge(0);
+        assert_eq!(vector, Some(0x25));
+    }
+
+    #[test]
+    fn test_irq_chip_apic_mode_irq0_remaps_to_gsi2() {
+        let mut chip = IrqChip::new(1);
+
+        // Enable LAPIC SVR.
+        chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x0F0, 4, 0x1FF);
+
+        // Configure IOAPIC pin 2: vector 0x22, unmasked, edge-triggered.
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE, 4, 0x14); // Select register 0x14 (pin 2 low)
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE + 0x10, 4, 0x22); // vector=0x22, unmasked
+        assert!(chip.apic_mode());
+
+        // raise_irq(0) should remap to IOAPIC pin 2 and deliver vector 0x22.
+        chip.raise_irq(0);
+        // pull_irr: merge shared state into local IRR (lock-free delivery path).
+        chip.lapics[0]
+            .lock()
+            .unwrap()
+            .pull_irr(&chip.shared_states[0]);
+        assert!(chip.has_pending(0));
+
+        let vector = chip.acknowledge(0);
+        assert_eq!(vector, Some(0x22));
+    }
+
+    #[test]
+    fn test_irq_chip_mmio_read_ioapic() {
+        let mut chip = IrqChip::new(1);
+        // Read IOAPIC version register.
+        chip.ioapic.write_mmio(0x00, 0x01); // Direct access to avoid transition check
+        let version = chip.handle_mmio_read(0, IOAPIC_MMIO_BASE + 0x10, 4);
+        assert_eq!(version, Some(0x0017_0011));
+    }
+
+    #[test]
+    fn test_irq_chip_mmio_read_lapic() {
+        let chip = IrqChip::new(1);
+        // Offset 0x030 is the LAPIC version register; the low byte is the version.
+        let version = chip.handle_mmio_read(0, LAPIC_MMIO_BASE + 0x030, 4);
+        assert!(version.is_some());
+        assert_eq!(version.unwrap() & 0xFF, 0x14);
+    }
+
+    #[test]
+    fn test_irq_chip_mmio_read_lapic_id_per_vcpu() {
+        let chip = IrqChip::new(2);
+        // vCPU 0 reads its own LAPIC ID.
+        assert_eq!(
+            chip.handle_mmio_read(0, LAPIC_MMIO_BASE + 0x020, 4),
+            Some(0 << 24)
+        );
+        // vCPU 1 reads its own LAPIC ID.
+        assert_eq!(
+            chip.handle_mmio_read(1, LAPIC_MMIO_BASE + 0x020, 4),
+            Some(1 << 24)
+        );
+    }
+
+    #[test]
+    fn test_irq_chip_mmio_read_unhandled() {
+        let chip = IrqChip::new(1);
+        // An address outside both MMIO windows is not claimed.
+        assert_eq!(chip.handle_mmio_read(0, 0xDEAD_0000, 4), None);
+    }
+
+    #[test]
+    fn test_irq_chip_mmio_write_unhandled() {
+        let mut chip = IrqChip::new(1);
+        let result = chip.handle_mmio_write(0, 0xDEAD_0000, 4, 0);
+        assert!(!result.handled);
+    }
+
+    #[test]
+    fn test_irq_chip_eoi_propagation() {
+        let mut chip = IrqChip::new(1);
+
+        // Enable LAPIC SVR.
+        chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x0F0, 4, 0x1FF);
+
+        // Configure IOAPIC pin 3: vector 0x33, level-triggered, unmasked.
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE, 4, 0x16); // register 0x16 = pin 3 low
+        chip.handle_mmio_write(0, IOAPIC_MMIO_BASE + 0x10, 4, 0x33 | (1 << 15)); // vector=0x33, level-triggered
+        assert!(chip.apic_mode());
+
+        // Raise IRQ 3.
+        chip.raise_irq(3);
+        // pull_irr: merge shared state into local IRR (lock-free delivery path).
+        chip.lapics[0]
+            .lock()
+            .unwrap()
+            .pull_irr(&chip.shared_states[0]);
+        let vector = chip.acknowledge(0);
+        assert_eq!(vector, Some(0x33));
+
+        // Inject and acknowledge in LAPIC.
+        chip.notify_injected(0, 0x33);
+
+        // Write EOI to LAPIC (offset 0x0B0).
+        chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x0B0, 4, 0);
+
+        // After EOI, the pin is still asserted → re-injection via shared state.
+        chip.lapics[0]
+            .lock()
+            .unwrap()
+            .pull_irr(&chip.shared_states[0]);
+        assert!(chip.has_pending(0));
+    }
+
+    #[test]
+    fn test_irq_chip_timer_only_in_apic_mode() {
+        let mut chip = IrqChip::new(1);
+        let now = Instant::now();
+        // In PIC mode, timer should not fire.
+        assert_eq!(chip.tick_timer(0, now), None);
+    }
+
+    #[test]
+    fn test_irq_chip_notify_injected_pic_mode() {
+        let mut chip = IrqChip::new(1);
+        // In PIC mode, notify_injected is a no-op.
+        // There is nothing to assert; the test only checks the call does not panic.
+        chip.notify_injected(0, 0x20);
+    }
+
+    #[test]
+    fn test_irq_chip_diagnostics() {
+        let chip = IrqChip::new(1);
+        // A fresh PIC reports nothing pending or in service and all lines masked.
+        let (irr, isr, imr, vbase) = chip.pic_master_state();
+        assert_eq!(irr, 0);
+        assert_eq!(isr, 0);
+        assert_eq!(imr, 0xFF);
+        assert_eq!(vbase, 0);
+    }
+
+    #[test]
+    fn test_irq_chip_deliver_ipi_interrupt() {
+        let mut chip = IrqChip::new(2);
+
+        // Enable APIC mode: enable BSP's LAPIC SVR + unmask IOAPIC entry.
+ chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x0F0, 4, 0x1FF); + chip.handle_mmio_write(0, IOAPIC_MMIO_BASE, 4, 0x14); + chip.handle_mmio_write(0, IOAPIC_MMIO_BASE + 0x10, 4, 0x22); + assert!(chip.apic_mode()); + + // Deliver IPI to vCPU 1. + chip.deliver_ipi_interrupt(1, 0x40); + // pull_irr: merge shared state into local IRR (lock-free delivery path). + chip.lapics[1] + .lock() + .unwrap() + .pull_irr(&chip.shared_states[1]); + assert!(chip.has_pending(1)); + assert_eq!(chip.acknowledge(1), Some(0x40)); + } + + #[test] + fn test_irq_chip_icr_write_returns_ipi_action() { + let mut chip = IrqChip::new(2); + + // Write ICR high on vCPU 0: destination = APIC 1. + chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x310, 4, 1 << 24); + // Write ICR low on vCPU 0: INIT delivery mode. + let result = chip.handle_mmio_write(0, LAPIC_MMIO_BASE + 0x300, 4, 0x0500); + assert!(result.handled); + assert_eq!(result.ipi_action, IpiAction::SendInit { target_apic_id: 1 }); + } + + #[test] + fn test_irq_chip_default_is_single_vcpu() { + let chip = IrqChip::default(); + assert_eq!(chip.num_vcpus(), 1); + } + + #[test] + fn test_get_lapic_ref_returns_correct_lapic() { + let chip = IrqChip::new(4); + for i in 0..4u32 { + let lapic_ref = chip.get_lapic_ref(i); + assert_eq!(lapic_ref.lock().unwrap().id(), i as u8); + } + } + + #[test] + fn test_concurrent_lapic_access() { + // Verify that per-LAPIC locks allow concurrent access from multiple threads. + let chip = IrqChip::new(4); + let refs: Vec>> = (0..4).map(|i| chip.get_lapic_ref(i)).collect(); + + std::thread::scope(|s| { + for (vcpu_id, lapic_ref) in refs.iter().enumerate() { + let lapic = lapic_ref.clone(); + s.spawn(move || { + // Each thread reads/writes its own LAPIC 1000 times. + for _ in 0..1000 { + let mut l = lapic.lock().unwrap(); + // Read LAPIC ID register (offset 0x020). + let id_val = l.read_mmio(0x020); + assert_eq!(id_val >> 24, vcpu_id as u32); + // Write TPR (offset 0x080). 
+ l.write_mmio(0x080, 0x10); + // Read TPR back. + let tpr = l.read_mmio(0x080); + assert_eq!(tpr, 0x10); + } + }); + } + }); + } + + // ---- Lock-free SharedApicState integration tests ---- + + #[test] + fn test_raise_irq_uses_shared_state() { + use super::lapic::SharedApicState; + + let mut chip = IrqChip::new(2); + // Enable APIC mode: SVR on vCPU 0. + chip.lapics[0].lock().unwrap().write_mmio(0x0F0, 0x1FF); + // Unmask IOAPIC entry 1 → GSI 1, vector 49, dest = LAPIC 0. + chip.ioapic.set_entry(1, 49, 0, false); + chip.apic_mode = true; + + // raise_irq goes through shared state (lock-free). + chip.raise_irq(1); + + // Before pull_irr, LAPIC has nothing. + assert_eq!( + chip.lapics[0].lock().unwrap().get_highest_injectable(), + None + ); + + // After pull_irr, LAPIC sees vector 49. + let shared = chip.get_shared_state(0); + chip.lapics[0].lock().unwrap().pull_irr(&shared); + assert_eq!( + chip.lapics[0].lock().unwrap().get_highest_injectable(), + Some(49) + ); + } + + #[test] + fn test_deliver_ipi_lock_free() { + use super::lapic::SharedApicState; + + let mut chip = IrqChip::new(2); + // Enable both LAPICs. + chip.lapics[0].lock().unwrap().write_mmio(0x0F0, 0x1FF); + chip.lapics[1].lock().unwrap().write_mmio(0x0F0, 0x1FF); + chip.apic_mode = true; + + // Deliver IPI: vector 80 to LAPIC 1. + chip.deliver_ipi_interrupt(1, 80); + + // Before pull_irr, LAPIC 1 has nothing. + assert_eq!( + chip.lapics[1].lock().unwrap().get_highest_injectable(), + None + ); + + // After pull_irr, LAPIC 1 sees vector 80. + let shared = chip.get_shared_state(1); + chip.lapics[1].lock().unwrap().pull_irr(&shared); + assert_eq!( + chip.lapics[1].lock().unwrap().get_highest_injectable(), + Some(80) + ); + + // LAPIC 0 should be unaffected. 
+        let shared0 = chip.get_shared_state(0);
+        chip.lapics[0].lock().unwrap().pull_irr(&shared0);
+        assert_eq!(
+            chip.lapics[0].lock().unwrap().get_highest_injectable(),
+            None
+        );
+    }
+}
diff --git a/src/vmm/src/windows/devices/lapic.rs b/src/vmm/src/windows/devices/lapic.rs
new file mode 100644
index 000000000..339ee52ed
--- /dev/null
+++ b/src/vmm/src/windows/devices/lapic.rs
@@ -0,0 +1,1042 @@
+//! Local APIC (LAPIC) emulation.
+//!
+//! Per-vCPU LAPIC for interrupt priority management and IPI delivery.
+//! Tracks IRR (Interrupt Request Register) and ISR (In-Service Register)
+//! as 256-bit vectors, and implements priority-based interrupt delivery.
+//!
+//! MMIO interface at 0xFEE0_0000 (4KB region):
+//! - 0x020: LAPIC ID
+//! - 0x030: LAPIC Version
+//! - 0x080: TPR (Task Priority Register)
+//! - 0x0B0: EOI (write-only)
+//! - 0x0F0: SVR (Spurious Vector Register)
+//! - 0x100-0x170: ISR (read-only, 256 bits)
+//! - 0x200-0x270: IRR (read-only, 256 bits)
+//! - 0x300: ICR Low (Interrupt Command Register)
+//! - 0x310: ICR High (destination APIC ID)
+//! - 0x320: LVT Timer
+//! - 0x380: Timer Initial Count
+//! - 0x390: Timer Current Count
+//! - 0x3E0: Timer Divide Configuration
+//!
+//! Any other offset reads as 0 and ignores writes.
+
+use std::sync::atomic::{AtomicU32, Ordering};
+use std::time::Instant;
+
+/// Shared APIC state for lock-free cross-vCPU interrupt delivery.
+///
+/// Other vCPUs atomically OR bits into `new_irr`. The owning vCPU
+/// periodically calls `pull_irr()` to merge into its local IRR.
+/// Inspired by OpenVMM's `virt_support_apic::SharedState`.
+pub struct SharedApicState {
+    /// Remote interrupt requests (256 bits = 8 x AtomicU32).
+    /// Source vCPUs atomic-OR the vector bit here.
+    /// Word `i` holds vectors `32 * i` through `32 * i + 31`.
+    new_irr: [AtomicU32; 8],
+}
+
+impl SharedApicState {
+    /// Create a new shared state with no pending interrupts.
+    pub fn new() -> Self {
+        Self {
+            // Atomics are not `Copy`, so the array is built element by element.
+            new_irr: std::array::from_fn(|_| AtomicU32::new(0)),
+        }
+    }
+
+    /// Atomically request an interrupt vector on this vCPU.
+ /// + /// Returns `true` if the bit was newly set (caller should wake target vCPU). + pub fn request_interrupt(&self, vector: u8) -> bool { + let (bank, mask) = bank_mask(vector); + let prev = self.new_irr[bank].fetch_or(mask, Ordering::Release); + prev & mask == 0 + } +} + +/// Compute the bank index and bit mask for a vector (0-255). +fn bank_mask(vector: u8) -> (usize, u32) { + let bank = (vector / 32) as usize; + let bit = vector % 32; + (bank, 1u32 << bit) +} + +/// Action resulting from an ICR write (Inter-Processor Interrupt). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IpiAction { + /// No IPI action (non-ICR write or unrecognized delivery mode). + None, + /// Fixed delivery: send interrupt vector to target LAPIC. + SendInterrupt { target_apic_id: u8, vector: u8 }, + /// Broadcast fixed interrupt to all vCPUs except the sender. + BroadcastInterrupt { source_apic_id: u8, vector: u8 }, + /// INIT delivery: reset target processor. + SendInit { target_apic_id: u8 }, + /// Startup IPI (SIPI): start target processor at vector * 0x1000. + SendSipi { target_apic_id: u8, vector: u8 }, +} + +/// Result of a LAPIC MMIO write operation. +#[derive(Debug, Clone, Copy)] +pub struct LapicWriteResult { + /// If an EOI was written, the vector that was cleared from ISR. + pub eoi_vector: Option, + /// If an ICR was written, the resulting IPI action. + pub ipi_action: IpiAction, +} + +impl Default for LapicWriteResult { + fn default() -> Self { + Self { + eoi_vector: None, + ipi_action: IpiAction::None, + } + } +} + +/// LAPIC version: integrated APIC with 6 LVT entries. +const LAPIC_VERSION: u32 = 0x0005_0014; // version 0x14, max LVT=5 + +/// SVR bit 8: APIC software enable. +const SVR_APIC_ENABLE: u32 = 1 << 8; + +/// LVT mask bit (bit 16). +const LVT_MASKED: u32 = 1 << 16; + +/// Timer modes. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TimerMode { + OneShot, + Periodic, +} + +/// Per-vCPU Local APIC. 
+pub struct LocalApic {
+    /// APIC ID (matches vCPU index).
+    id: u8,
+    /// 256-bit Interrupt Request Register (8 x 32-bit words).
+    irr: [u32; 8],
+    /// 256-bit In-Service Register.
+    isr: [u32; 8],
+    /// Task Priority Register (only low 8 bits used).
+    tpr: u8,
+    /// Spurious Vector Register (bit 8 = APIC enabled).
+    svr: u32,
+
+    // ICR (Interrupt Command Register) for IPI support.
+    /// ICR low 32 bits (vector, delivery mode, destination shorthand).
+    icr_low: u32,
+    /// ICR high 32 bits (destination APIC ID in bits 31:24).
+    icr_high: u32,
+
+    // Timer state.
+    /// Timer mode.
+    timer_mode: TimerMode,
+    /// LVT Timer vector.
+    timer_vector: u8,
+    /// LVT Timer mask.
+    timer_masked: bool,
+    /// Divide configuration register value.
+    timer_divide_reg: u32,
+    /// Computed divisor (1, 2, 4, 8, 16, 32, 64, 128).
+    timer_divisor: u32,
+    /// Initial count register.
+    timer_initial: u32,
+    /// When the timer fires next (host time). `None` while the timer is disarmed.
+    timer_deadline: Option<Instant>,
+    /// Timer period for periodic mode.
+    timer_period_ns: u64,
+}
+
+impl Default for LocalApic {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl LocalApic {
+    /// Create a new LAPIC with default state (disabled), APIC ID = 0.
+    pub fn new() -> Self {
+        Self::new_with_id(0)
+    }
+
+    /// Create a new LAPIC with a specific APIC ID (disabled by default).
+    ///
+    /// The timer starts masked and disarmed with a divisor of 2, which is the
+    /// divisor that a divide-configuration value of 0 decodes to.
+    pub fn new_with_id(id: u8) -> Self {
+        Self {
+            id,
+            irr: [0; 8],
+            isr: [0; 8],
+            tpr: 0,
+            svr: 0, // APIC disabled by default
+
+            icr_low: 0,
+            icr_high: 0,
+
+            timer_mode: TimerMode::OneShot,
+            timer_vector: 0,
+            timer_masked: true,
+            timer_divide_reg: 0,
+            timer_divisor: 2, // Default divisor
+            timer_initial: 0,
+            timer_deadline: None,
+            timer_period_ns: 0,
+        }
+    }
+
+    /// Get the APIC ID.
+    pub fn id(&self) -> u8 {
+        self.id
+    }
+
+    /// Whether the LAPIC is software-enabled (SVR bit 8).
+    pub fn is_enabled(&self) -> bool {
+        self.svr & SVR_APIC_ENABLE != 0
+    }
+
+    /// Pull remote interrupt requests from the shared state into the local IRR.
+    ///
+    /// Atomically swaps each bank to 0 and ORs the bits into the local IRR.
+    /// Called at the top of each vCPU loop iteration (lock-free fast path).
+    pub fn pull_irr(&mut self, shared: &SharedApicState) {
+        for (local, remote) in self.irr.iter_mut().zip(&shared.new_irr) {
+            // `swap` takes every pending bit exactly once; a vector requested
+            // after the swap stays in the shared word for the next pull.
+            *local |= remote.swap(0, Ordering::Acquire);
+        }
+    }
+
+    /// Accept an interrupt vector into the IRR.
+    pub fn accept_interrupt(&mut self, vector: u8) {
+        let (word, mask) = bank_mask(vector);
+        self.irr[word] |= mask;
+    }
+
+    /// Get the highest-priority vector in IRR that beats the current
+    /// Processor Priority (PPR = max(TPR, highest ISR vector class)).
+    ///
+    /// Returns None if no injectable vector exists.
+    pub fn get_highest_injectable(&self) -> Option<u8> {
+        let highest_irr = Self::highest_bit(&self.irr)?;
+        let ppr = self.processor_priority();
+
+        // Only the priority class (vector >> 4) is compared, and it must be
+        // strictly greater than the PPR class. A pending vector in the same
+        // class as the one in service (or as TPR) therefore has to wait.
+        if (highest_irr >> 4) > (ppr >> 4) {
+            Some(highest_irr)
+        } else {
+            None
+        }
+    }
+
+    /// Called when the vector is actually injected into the vCPU.
+    /// Moves the vector from IRR to ISR.
+    pub fn start_of_interrupt(&mut self, vector: u8) {
+        let (word, mask) = bank_mask(vector);
+        self.irr[word] &= !mask;
+        self.isr[word] |= mask;
+    }
+
+    /// Handle End-of-Interrupt.
+    ///
+    /// Clears the highest-priority ISR bit.
+    /// Returns the vector that was cleared (for IOAPIC EOI broadcast), or
+    /// `None` when nothing was in service.
+    pub fn end_of_interrupt(&mut self) -> Option<u8> {
+        let highest = Self::highest_bit(&self.isr)?;
+        let (word, mask) = bank_mask(highest);
+        self.isr[word] &= !mask;
+        Some(highest)
+    }
+
+    /// Compute the LAPIC timer current count register value.
+    ///
+    /// Returns the remaining count based on elapsed time since the timer was
+    /// armed. The kernel reads this during timer calibration and busy-waits.
+    ///
+    /// Returns 0 when the timer is disarmed or the deadline has passed. The
+    /// result never exceeds the initial count.
+    fn current_count(&self) -> u32 {
+        let remaining = match self.timer_deadline {
+            // Saturates to zero once the deadline is in the past.
+            Some(deadline) => deadline.saturating_duration_since(Instant::now()),
+            None => return 0,
+        };
+
+        // One timer tick is modelled as 100ns times the divisor, matching the
+        // period computed when the initial count is written.
+        let tick_ns = 100 * u64::from(self.timer_divisor);
+        let remaining_ticks = remaining.as_nanos() as u64 / tick_ns;
+
+        // Clamp in u64 before narrowing: if the guest lowers the divisor after
+        // arming, the tick count can exceed u32::MAX, and truncating first
+        // would wrap instead of saturating at the initial count.
+        remaining_ticks.min(u64::from(self.timer_initial)) as u32
+    }
+
+    /// Tick the LAPIC timer. Returns the timer vector if it fired.
+    ///
+    /// A masked timer, a zero initial count, or a disarmed timer never fires.
+    pub fn tick_timer(&mut self, now: Instant) -> Option<u8> {
+        if self.timer_masked || self.timer_initial == 0 {
+            return None;
+        }
+
+        let deadline = self.timer_deadline?;
+        if now < deadline {
+            return None;
+        }
+
+        self.timer_deadline = match self.timer_mode {
+            TimerMode::OneShot => None,
+            TimerMode::Periodic => {
+                // Rearm timer for next period, measured from the old deadline
+                // rather than from `now`, so the schedule does not drift.
+                let period = std::time::Duration::from_nanos(self.timer_period_ns);
+                Some(deadline + period)
+            }
+        };
+
+        Some(self.timer_vector)
+    }
+
+    /// Read from the LAPIC MMIO region.
+ pub fn read_mmio(&self, offset: u64) -> u32 { + match offset { + 0x020 => (self.id as u32) << 24, // LAPIC ID + 0x030 => LAPIC_VERSION, // Version + 0x080 => self.tpr as u32, // TPR + 0x0B0 => 0, // EOI (write-only) + 0x0F0 => self.svr, // SVR + // ISR: 0x100, 0x110, 0x120, ..., 0x170 + 0x100..=0x170 if offset & 0x0F == 0 => { + let idx = ((offset - 0x100) / 0x10) as usize; + if idx < 8 { + self.isr[idx] + } else { + 0 + } + } + // IRR: 0x200, 0x210, 0x220, ..., 0x270 + 0x200..=0x270 if offset & 0x0F == 0 => { + let idx = ((offset - 0x200) / 0x10) as usize; + if idx < 8 { + self.irr[idx] + } else { + 0 + } + } + 0x300 => self.icr_low, // ICR Low + 0x310 => self.icr_high, // ICR High + 0x320 => self.read_lvt_timer(), // LVT Timer + 0x380 => self.timer_initial, // Timer Initial Count + 0x390 => self.current_count(), // Timer Current Count + 0x3E0 => self.timer_divide_reg, // Timer Divide Configuration + _ => 0, + } + } + + /// Result of a LAPIC MMIO write. + /// + /// Contains an optional EOI vector and an IPI action from ICR writes. + pub fn write_mmio(&mut self, offset: u64, value: u32) -> LapicWriteResult { + match offset { + 0x080 => { + self.tpr = (value & 0xFF) as u8; + LapicWriteResult::default() + } + 0x0B0 => { + // EOI: clear highest ISR, return vector for IOAPIC. + LapicWriteResult { + eoi_vector: self.end_of_interrupt(), + ipi_action: IpiAction::None, + } + } + 0x0F0 => { + self.svr = value; + log::debug!( + "LAPIC {} SVR write: {:#X} (enabled={})", + self.id, + value, + value & SVR_APIC_ENABLE != 0 + ); + LapicWriteResult::default() + } + 0x300 => { + // ICR Low write triggers IPI delivery. + self.icr_low = value; + let action = self.parse_icr(); + LapicWriteResult { + eoi_vector: None, + ipi_action: action, + } + } + 0x310 => { + // ICR High: destination APIC ID (bits 31:24). 
+ self.icr_high = value; + LapicWriteResult::default() + } + 0x320 => { + self.write_lvt_timer(value); + LapicWriteResult::default() + } + 0x380 => { + self.write_initial_count(value); + LapicWriteResult::default() + } + 0x3E0 => { + self.write_divide_config(value); + LapicWriteResult::default() + } + _ => LapicWriteResult::default(), + } + } + + /// Parse the ICR low/high registers to produce an IPI action. + /// + /// ICR Low bits: + /// - [7:0] Vector + /// - [10:8] Delivery mode (000=Fixed, 101=INIT, 110=SIPI) + /// - [11] Destination mode (0=physical, 1=logical) + /// - [17:12] Reserved/status + /// - [19:18] Destination shorthand (00=none, 01=self, 10=all-incl-self, 11=all-excl-self) + fn parse_icr(&self) -> IpiAction { + let vector = (self.icr_low & 0xFF) as u8; + let delivery_mode = (self.icr_low >> 8) & 0x7; + let dest_shorthand = (self.icr_low >> 18) & 0x3; + let dest_apic_id = ((self.icr_high >> 24) & 0xFF) as u8; + + // Handle destination shorthand first. + match dest_shorthand { + 0b01 => { + // Self: send to own LAPIC (used for self-IPI). + log::debug!("LAPIC {} ICR: Self IPI vector={:#X}", self.id, vector); + return IpiAction::SendInterrupt { + target_apic_id: self.id, + vector, + }; + } + 0b10 | 0b11 => { + // All Including Self (0b10) or All Excluding Self (0b11). + // For fixed delivery, broadcast to all other vCPUs. + if delivery_mode == 0b000 { + log::debug!( + "LAPIC {} ICR: Broadcast vector={:#X} (shorthand={})", + self.id, + vector, + if dest_shorthand == 0b10 { + "all-incl" + } else { + "all-excl" + } + ); + return IpiAction::BroadcastInterrupt { + source_apic_id: self.id, + vector, + }; + } + // Non-fixed broadcast (e.g., INIT to all) — fallthrough to per-target. + // For now, treat as no-op (Linux doesn't broadcast INIT/SIPI with shorthand). 
+ log::debug!( + "LAPIC {} ICR: Broadcast delivery_mode={} (unsupported, ignored)", + self.id, + delivery_mode + ); + return IpiAction::None; + } + _ => { + // 0b00: No shorthand — use destination field (normal path). + } + } + + match delivery_mode { + 0b000 => { + // Fixed delivery. + log::debug!( + "LAPIC {} ICR: Fixed interrupt vector={:#X} → APIC {}", + self.id, + vector, + dest_apic_id + ); + IpiAction::SendInterrupt { + target_apic_id: dest_apic_id, + vector, + } + } + 0b101 => { + // INIT delivery. + log::debug!("LAPIC {} ICR: INIT → APIC {}", self.id, dest_apic_id); + IpiAction::SendInit { + target_apic_id: dest_apic_id, + } + } + 0b110 => { + // Startup IPI (SIPI). + log::debug!( + "LAPIC {} ICR: SIPI vector={:#X} → APIC {} (start at {:#X})", + self.id, + vector, + dest_apic_id, + (vector as u32) * 0x1000 + ); + IpiAction::SendSipi { + target_apic_id: dest_apic_id, + vector, + } + } + _ => { + log::debug!( + "LAPIC {} ICR: unsupported delivery mode {} → APIC {}", + self.id, + delivery_mode, + dest_apic_id + ); + IpiAction::None + } + } + } + + /// Compute Processor Priority Register (PPR). + /// + /// PPR = max(TPR, highest ISR priority class) — determines the minimum + /// priority class that can be delivered. + fn processor_priority(&self) -> u8 { + let isr_class = Self::highest_bit(&self.isr).map(|v| v & 0xF0).unwrap_or(0); + let tpr_class = self.tpr & 0xF0; + std::cmp::max(isr_class, tpr_class) + } + + /// Find the highest set bit across an 8-word (256-bit) register. + /// Returns the bit index (0-255) or None if all zero. + fn highest_bit(reg: &[u32; 8]) -> Option { + for word_idx in (0..8).rev() { + let word = reg[word_idx]; + if word != 0 { + let bit = 31 - word.leading_zeros(); + return Some((word_idx as u8) * 32 + bit as u8); + } + } + None + } + + /// Read the LVT Timer register. 
+    ///
+    /// Rebuilds the value from the stored vector (bits 7:0), mask (bit 16) and
+    /// mode (bit 17 set for periodic); all other bits read as 0.
+    fn read_lvt_timer(&self) -> u32 {
+        let mut val = self.timer_vector as u32;
+        if self.timer_masked {
+            val |= LVT_MASKED;
+        }
+        if self.timer_mode == TimerMode::Periodic {
+            val |= 1 << 17;
+        }
+        val
+    }
+
+    /// Write the LVT Timer register.
+    ///
+    /// Only the vector, the mask bit and bit 17 (periodic) are decoded. The
+    /// deadline is left untouched; only an initial-count write arms the timer.
+    fn write_lvt_timer(&mut self, value: u32) {
+        self.timer_vector = (value & 0xFF) as u8;
+        self.timer_masked = value & LVT_MASKED != 0;
+        self.timer_mode = if value & (1 << 17) != 0 {
+            TimerMode::Periodic
+        } else {
+            TimerMode::OneShot
+        };
+    }
+
+    /// Write the Timer Initial Count register.
+    ///
+    /// A value of 0 disarms the timer. Any other value arms it relative to the
+    /// current host time, using the divisor in effect at the moment of the write.
+    fn write_initial_count(&mut self, value: u32) {
+        self.timer_initial = value;
+        if value == 0 {
+            self.timer_deadline = None;
+            return;
+        }
+
+        // Compute timer period: initial_count * divisor * base_period.
+        // Base period is ~100ns (approximation of bus clock period).
+        // This gives reasonable timer behavior for Linux's LAPIC timer driver.
+        let ticks = value as u64 * self.timer_divisor as u64;
+        self.timer_period_ns = ticks * 100; // ~100ns per bus clock tick
+        let period = std::time::Duration::from_nanos(self.timer_period_ns);
+        self.timer_deadline = Some(Instant::now() + period);
+    }
+
+    /// Write the Timer Divide Configuration register.
+    ///
+    /// Updates the stored register value and the decoded divisor only; an
+    /// already-armed deadline is not recomputed.
+    fn write_divide_config(&mut self, value: u32) {
+        self.timer_divide_reg = value & 0x0B; // Only bits 0,1,3 are used.
+        // Decode divisor: bits [3,1,0] encode the divisor.
+        // 0b000=2, 0b001=4, 0b010=8, 0b011=16,
+        // 0b100=32, 0b101=64, 0b110=128, 0b111=1
+        let div_bits = ((value & 0x08) >> 1) | (value & 0x03);
+        self.timer_divisor = match div_bits {
+            0b000 => 2,
+            0b001 => 4,
+            0b010 => 8,
+            0b011 => 16,
+            0b100 => 32,
+            0b101 => 64,
+            0b110 => 128,
+            0b111 => 1,
+            // `div_bits` is a 3-bit value, so this arm only satisfies exhaustiveness.
+            _ => 2,
+        };
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_lapic_initial_state() {
+        let lapic = LocalApic::new();
+        assert_eq!(lapic.id, 0);
+        assert!(!lapic.is_enabled());
+        assert_eq!(lapic.tpr, 0);
+        assert!(lapic.timer_masked);
+    }
+
+    #[test]
+    fn test_lapic_enable_via_svr() {
+        let mut lapic = LocalApic::new();
+        lapic.write_mmio(0x0F0, SVR_APIC_ENABLE | 0xFF);
+        assert!(lapic.is_enabled());
+    }
+
+    #[test]
+    fn test_lapic_accept_and_get_injectable() {
+        let mut lapic = LocalApic::new();
+
+        // Accept vector 0x30.
+        lapic.accept_interrupt(0x30);
+        assert_eq!(lapic.get_highest_injectable(), Some(0x30));
+    }
+
+    #[test]
+    fn test_lapic_priority_ordering() {
+        let mut lapic = LocalApic::new();
+
+        // Accept vectors 0x30 and 0x50 — 0x50 has higher priority.
+        lapic.accept_interrupt(0x30);
+        lapic.accept_interrupt(0x50);
+        assert_eq!(lapic.get_highest_injectable(), Some(0x50));
+    }
+
+    #[test]
+    fn test_lapic_isr_blocks_lower_priority() {
+        let mut lapic = LocalApic::new();
+
+        // Put vector 0x50 in service.
+        lapic.accept_interrupt(0x50);
+        lapic.start_of_interrupt(0x50);
+
+        // Accept vector 0x30 — lower priority class, should be blocked.
+        lapic.accept_interrupt(0x30);
+        assert_eq!(lapic.get_highest_injectable(), None);
+
+        // Accept vector 0x60 — higher priority, should be injectable.
+        lapic.accept_interrupt(0x60);
+        assert_eq!(lapic.get_highest_injectable(), Some(0x60));
+    }
+
+    #[test]
+    fn test_lapic_tpr_blocks_low_priority() {
+        let mut lapic = LocalApic::new();
+
+        // Set TPR to class 5 (0x50) — blocks vectors 0x00-0x5F.
+        lapic.write_mmio(0x080, 0x50);
+
+        // 0x30 is below the TPR class, so it stays pending but not injectable.
+        lapic.accept_interrupt(0x30);
+        assert_eq!(lapic.get_highest_injectable(), None);
+
+        // 0x60 is in a strictly higher class than TPR.
+        lapic.accept_interrupt(0x60);
+        assert_eq!(lapic.get_highest_injectable(), Some(0x60));
+    }
+
+    #[test]
+    fn test_lapic_start_of_interrupt() {
+        let mut lapic = LocalApic::new();
+
+        lapic.accept_interrupt(0x30);
+        assert!(lapic.irr[1] & (1 << 16) != 0); // 0x30 = word 1, bit 16
+
+        lapic.start_of_interrupt(0x30);
+        assert_eq!(lapic.irr[1] & (1 << 16), 0, "IRR should be cleared");
+        assert!(lapic.isr[1] & (1 << 16) != 0, "ISR should be set");
+    }
+
+    #[test]
+    fn test_lapic_eoi_clears_isr() {
+        let mut lapic = LocalApic::new();
+
+        lapic.accept_interrupt(0x30);
+        lapic.start_of_interrupt(0x30);
+
+        let vector = lapic.end_of_interrupt();
+        assert_eq!(vector, Some(0x30));
+        assert_eq!(lapic.isr[1] & (1 << 16), 0, "ISR should be cleared");
+    }
+
+    #[test]
+    fn test_lapic_eoi_clears_highest_isr() {
+        let mut lapic = LocalApic::new();
+
+        // Put two vectors in service.
+        lapic.accept_interrupt(0x30);
+        lapic.start_of_interrupt(0x30);
+        lapic.accept_interrupt(0x50);
+        lapic.start_of_interrupt(0x50);
+
+        // EOI clears highest (0x50).
+        let vector = lapic.end_of_interrupt();
+        assert_eq!(vector, Some(0x50));
+
+        // Next EOI clears 0x30.
+        let vector = lapic.end_of_interrupt();
+        assert_eq!(vector, Some(0x30));
+    }
+
+    #[test]
+    fn test_lapic_eoi_empty_isr() {
+        let mut lapic = LocalApic::new();
+        // With nothing in service, EOI has no vector to report.
+        assert_eq!(lapic.end_of_interrupt(), None);
+    }
+
+    #[test]
+    fn test_lapic_mmio_read_id() {
+        let lapic = LocalApic::new();
+        assert_eq!(lapic.read_mmio(0x020), 0); // ID = 0, shifted left 24
+    }
+
+    #[test]
+    fn test_lapic_mmio_read_version() {
+        let lapic = LocalApic::new();
+        let version = lapic.read_mmio(0x030);
+        assert_eq!(version & 0xFF, 0x14);
+    }
+
+    #[test]
+    fn test_lapic_mmio_svr_roundtrip() {
+        let mut lapic = LocalApic::new();
+        lapic.write_mmio(0x0F0, 0x1FF);
+        assert_eq!(lapic.read_mmio(0x0F0), 0x1FF);
+    }
+
+    #[test]
+    fn test_lapic_mmio_tpr_roundtrip() {
+        let mut lapic = LocalApic::new();
+        lapic.write_mmio(0x080, 0x40);
+        assert_eq!(lapic.read_mmio(0x080), 0x40);
+    }
+
+    #[test]
+    fn test_lapic_mmio_eoi_returns_vector() {
+        let mut lapic = LocalApic::new();
+        lapic.accept_interrupt(0x30);
+        lapic.start_of_interrupt(0x30);
+
+        // The value written to the EOI register is ignored; only the write matters.
+        let result = lapic.write_mmio(0x0B0, 0);
+        assert_eq!(result.eoi_vector, Some(0x30));
+    }
+
+    #[test]
+    fn test_lapic_mmio_isr_read() {
+        let mut lapic = LocalApic::new();
+        lapic.accept_interrupt(0x30);
+        lapic.start_of_interrupt(0x30);
+
+        // 0x30 = word 1 (offset 0x110)
+        assert_ne!(lapic.read_mmio(0x110), 0);
+        assert_eq!(lapic.read_mmio(0x100), 0); // Word 0 should be empty.
+    }
+
+    #[test]
+    fn test_lapic_mmio_irr_read() {
+        let mut lapic = LocalApic::new();
+        lapic.accept_interrupt(0x30);
+
+        // 0x30 = word 1 (offset 0x210)
+        assert_ne!(lapic.read_mmio(0x210), 0);
+        assert_eq!(lapic.read_mmio(0x200), 0);
+    }
+
+    #[test]
+    fn test_lapic_lvt_timer_write_read() {
+        let mut lapic = LocalApic::new();
+
+        // Set timer: vector=0x20, periodic, unmasked.
+        let lvt = 0x20 | (1 << 17); // vector=0x20, periodic
+        lapic.write_mmio(0x320, lvt);
+
+        let read = lapic.read_mmio(0x320);
+        assert_eq!(read & 0xFF, 0x20);
+        assert!(read & (1 << 17) != 0, "periodic bit");
+        assert!(read & LVT_MASKED == 0, "should be unmasked");
+    }
+
+    #[test]
+    fn test_lapic_timer_divide_config() {
+        let mut lapic = LocalApic::new();
+
+        // Divisor = 1 (bits [3,1,0] = 0b111 → register value = 0b1011 = 0x0B)
+        lapic.write_mmio(0x3E0, 0x0B);
+        assert_eq!(lapic.timer_divisor, 1);
+
+        // Divisor = 16 (bits [3,1,0] = 0b011 → register value = 0b0011 = 0x03)
+        lapic.write_mmio(0x3E0, 0x03);
+        assert_eq!(lapic.timer_divisor, 16);
+    }
+
+    #[test]
+    fn test_lapic_timer_fires_oneshot() {
+        let mut lapic = LocalApic::new();
+
+        // Configure: vector=0x20, oneshot, unmasked, divisor=1
+        lapic.write_mmio(0x320, 0x20); // vector=0x20, oneshot, unmasked
+        lapic.write_mmio(0x3E0, 0x0B); // divisor=1
+
+        // Set initial count → arms the timer.
+        lapic.write_mmio(0x380, 1); // count=1
+
+        // Timer should fire after some time.
+        // A synthetic "now" well past the deadline avoids sleeping in the test.
+        let future = Instant::now() + std::time::Duration::from_millis(100);
+        let vector = lapic.tick_timer(future);
+        assert_eq!(vector, Some(0x20));
+
+        // Second tick should not fire (oneshot).
+        let vector = lapic.tick_timer(future + std::time::Duration::from_millis(100));
+        assert_eq!(vector, None);
+    }
+
+    #[test]
+    fn test_lapic_timer_masked_no_fire() {
+        let mut lapic = LocalApic::new();
+
+        // Configure: masked
+        lapic.write_mmio(0x320, 0x20 | LVT_MASKED);
+        lapic.write_mmio(0x3E0, 0x0B);
+        lapic.write_mmio(0x380, 1);
+
+        let future = Instant::now() + std::time::Duration::from_millis(100);
+        assert_eq!(lapic.tick_timer(future), None);
+    }
+
+    #[test]
+    fn test_lapic_timer_zero_count_disarms() {
+        let mut lapic = LocalApic::new();
+
+        lapic.write_mmio(0x320, 0x20);
+        lapic.write_mmio(0x3E0, 0x0B);
+        lapic.write_mmio(0x380, 0); // count=0 disarms
+
+        let future = Instant::now() + std::time::Duration::from_millis(100);
+        assert_eq!(lapic.tick_timer(future), None);
+    }
+
+    #[test]
+    fn test_lapic_highest_bit() {
+        let mut reg = [0u32; 8];
+        assert_eq!(LocalApic::highest_bit(&reg), None);
+
+        reg[0] = 1; // bit 0
+        assert_eq!(LocalApic::highest_bit(&reg), Some(0));
+
+        reg[7] = 1 << 31; // bit 255
+        assert_eq!(LocalApic::highest_bit(&reg), Some(255));
+
+        reg[3] = 1 << 16; // bit 112
+        // Highest should still be 255.
+        assert_eq!(LocalApic::highest_bit(&reg), Some(255));
+    }
+
+    #[test]
+    fn test_lapic_processor_priority() {
+        let mut lapic = LocalApic::new();
+
+        // No ISR, TPR=0 → PPR=0.
+        assert_eq!(lapic.processor_priority(), 0);
+
+        // TPR=0x40 → PPR=0x40.
+        lapic.tpr = 0x40;
+        assert_eq!(lapic.processor_priority(), 0x40);
+
+        // ISR has 0x50 → PPR=max(0x40, 0x50)=0x50.
#[test]
fn test_lapic_icr_read_write_roundtrip() {
    let mut apic = LocalApic::new();

    // ICR high (0x310) carries the destination APIC ID in bits 31:24.
    apic.write_mmio(0x310, 1 << 24);
    assert_eq!(apic.read_mmio(0x310), 1 << 24);

    // ICR low (0x300): vector 0x40 with Fixed delivery. Writing it is what
    // produces the IPI action.
    let outcome = apic.write_mmio(0x300, 0x40);
    assert_eq!(apic.read_mmio(0x300), 0x40);

    let IpiAction::SendInterrupt {
        target_apic_id,
        vector,
    } = outcome.ipi_action
    else {
        panic!("expected SendInterrupt, got {:?}", outcome.ipi_action);
    };
    assert_eq!(target_apic_id, 1);
    assert_eq!(vector, 0x40);
}
#[test]
fn test_shared_request_interrupt() {
    let state = SharedApicState::new();

    // Vector 32 → bank 1, bit 0; vector 33 → bank 1, bit 1.
    let first_request = state.request_interrupt(32);
    let repeated_request = state.request_interrupt(32);
    let neighbouring_vector = state.request_interrupt(33);

    assert!(first_request); // bit was clear → true
    assert!(!repeated_request); // bit already set → false
    assert!(neighbouring_vector); // different bit → true
}
#[test]
fn test_shared_concurrent_ipi() {
    use std::sync::Arc;

    let shared = Arc::new(SharedApicState::new());

    // Eight threads, one distinct vector each (32..=39), all posting into
    // the same shared state concurrently. Each must see its bit as new.
    std::thread::scope(|scope| {
        for vector in 32u8..40 {
            let state = Arc::clone(&shared);
            scope.spawn(move || {
                assert!(state.request_interrupt(vector));
            });
        }
    });

    // Drain the shared state into a LAPIC and enable it (SVR bit 8).
    // Only the highest pending vector is checked here.
    let mut apic = LocalApic::new();
    apic.pull_irr(&shared);
    apic.write_mmio(0x0F0, 0x1FF);
    assert_eq!(apic.get_highest_injectable(), Some(39));
}
/// Shared console output buffer.
pub type ConsoleBuffer = Arc<Mutex<Vec<u8>>>;

/// Writer that copies output to both an inner writer and a shared buffer.
///
/// NOTE(review): the generic arguments in this section were lost in
/// extraction; `Box<dyn Write + Send>` and `Vec<u8>` are reconstructed from
/// usage — confirm against the real file.
struct TeeWriter {
    inner: Box<dyn Write + Send>,
    buffer: ConsoleBuffer,
}

impl Write for TeeWriter {
    /// Forward `buf` to the inner writer and mirror the accepted bytes.
    ///
    /// `Write::write` may accept only a prefix of `buf`; callers such as
    /// `write_all` then retry with the remainder. Only the accepted prefix
    /// is captured, so retries never duplicate bytes in the shared buffer
    /// and a failed write captures nothing.
    ///
    /// # Panics
    ///
    /// Panics if the shared buffer's mutex is poisoned.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = self.inner.write(buf)?;
        self.buffer
            .lock()
            .unwrap()
            .extend_from_slice(&buf[..written]);
        Ok(written)
    }

    /// Flush the inner writer; the capture buffer needs no flushing.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}
// Default vsock listen ports (BoxLite: 2695=gRPC, 2696=ready signal).
// NOTE(review): this was a `///` doc comment with no item of its own, so
// rustdoc attached it to `to_bcd`. Kept as a plain comment until the
// constant it describes is restored or the note is removed.

/// Convert a value to packed BCD (Binary-Coded Decimal).
///
/// E.g. 26 → 0x26, 59 → 0x59. Only `0..=99` fits in two BCD digits; larger
/// inputs do not produce valid BCD.
fn to_bcd(val: u8) -> u8 {
    ((val / 10) << 4) | (val % 10)
}

/// Snapshot of host UTC time, stored as BCD values for CMOS register reads.
struct CmosTime {
    seconds: u8,
    minutes: u8,
    hours: u8,
    day_of_week: u8,
    day_of_month: u8,
    month: u8,
    year: u8,    // Two-digit year in BCD (e.g. 0x26 for 2026)
    century: u8, // Century in BCD (e.g. 0x20)
}

impl CmosTime {
    /// Capture the current host UTC time.
    ///
    /// A host clock set before the Unix epoch is treated as the epoch itself.
    fn now() -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};

        let secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self::from_unix_secs(secs)
    }

    /// Build the BCD snapshot for `secs` seconds after the Unix epoch (UTC).
    fn from_unix_secs(secs: u64) -> Self {
        let days = (secs / 86_400) as i64;
        let time_of_day = secs % 86_400;
        let (year, month, day) = civil_from_days(days);

        let hour = (time_of_day / 3600) as u8;
        let minute = ((time_of_day % 3600) / 60) as u8;
        let second = (time_of_day % 60) as u8;

        // 1970-01-01 was a Thursday (index 4 when 0=Sun). CMOS numbers the
        // days 1=Sun..7=Sat.
        let dow = (((days % 7) + 4) % 7) as u8 + 1;

        let year_full = year as u16;
        Self {
            seconds: to_bcd(second),
            minutes: to_bcd(minute),
            hours: to_bcd(hour),
            day_of_week: to_bcd(dow),
            day_of_month: to_bcd(day),
            month: to_bcd(month),
            year: to_bcd((year_full % 100) as u8),
            century: to_bcd((year_full / 100) as u8),
        }
    }
}

/// Convert days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` with `month` in `1..=12` and `day` in `1..=31`.
///
/// Uses the March-based era/day-of-era decomposition from Howard Hinnant's
/// chrono-compatible date algorithms.
fn civil_from_days(days: i64) -> (i64, u8, u8) {
    let z = days + 719_468;
    // Floor division, so days before the era origin land in the right era.
    let era = z.div_euclid(146_097);
    let doe = z.rem_euclid(146_097) as u64; // day of era [0, 146096]
    let yoe = (doe - doe / 1460 + doe / 36_524 - doe / 146_096) / 365;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
    let mp = (5 * doy + 2) / 153; // March-based month index [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = yoe as i64 + era * 400 + i64::from(month <= 2);
    (year, month as u8, day as u8)
}
+fn cmos_read(addr: u8) -> u8 { + let t = &*CMOS_TIME; + match addr { + 0x00 => t.seconds, + 0x02 => t.minutes, + 0x04 => t.hours, + 0x06 => t.day_of_week, + 0x07 => t.day_of_month, + 0x08 => t.month, + 0x09 => t.year, + 0x0A => 0x26, // Status A: no update in progress, 32.768 kHz + 0x0B => 0x02, // Status B: 24-hour, BCD mode + 0x0C => 0x00, // Status C: no interrupt source + 0x0D => 0x80, // Status D: battery OK + 0x0E => 0x00, // Diagnostic status + 0x0F => 0x00, // Shutdown status + 0x10 => 0x00, // Floppy drive type + 0x12 => 0x00, // Hard drive type + 0x15 => 0x80, // Base memory low byte (640KB = 0x0280) + 0x16 => 0x02, // Base memory high byte + 0x17 => 0x00, // Extended memory low (kernel uses E820) + 0x18 => 0x00, // Extended memory high + 0x32 => t.century, + _ => 0x00, + } +} + +/// Result of creating devices from a VmContext. +pub struct DeviceSetup { + /// The device manager. + pub devices: DeviceManager, + /// MMIO slots to include in the kernel command line. + pub mmio_slots: Vec, + /// Whether a root disk is present. + pub has_root_disk: bool, + /// Shared console output buffer (captures all serial output). + pub console_buffer: ConsoleBuffer, +} + +/// Centralized device manager for all emulated devices. +pub struct DeviceManager { + serial: Serial, + pub irq_chip: IrqChip, + pit: Pit, + cmos_addr: u8, + + /// Virtio-blk device (slot 0) — optional. + virtio_blk: Option>, + /// Virtio-vsock device (slot 1). + virtio_vsock: VirtioMmioDevice, + /// Virtio-9p device (slot 2) — optional. + virtio_9p: Option>, + /// Virtio-net device (slot 3) — optional. + virtio_net: Option>, + /// Second virtio-blk device (slot 4) — optional, for guest rootfs. + virtio_blk2: Option>, + /// Virtio-rng device (slot 5) — always present. + virtio_rng: VirtioMmioDevice, + /// Virtio-balloon device (slot 6) — always present. + virtio_balloon: VirtioMmioDevice, + + /// Diagnostic: count QUEUE_NOTIFY writes to blk devices. 
+ blk_queue_notify_count: u64, + /// Diagnostic: count block I/O completions drained. + blk_completion_count: u64, + /// Diagnostic: count MMIO accesses to IOAPIC range. + ioapic_mmio_count: u64, + /// Diagnostic: count MMIO accesses to LAPIC range. + lapic_mmio_count: u64, + + /// Track whether we've requested an interrupt window. + window_requested: bool, + /// Last PIT tick timestamp. + last_tick: Instant, + /// Toggle state for port 0x61 bit 5 (PIT counter 2 output). + /// + /// Linux's `pit_calibrate_tsc()` loops reading port 0x61 waiting for + /// bit 5 to toggle. Without toggling, TSC calibration stalls forever. + port61_toggle: bool, + /// ACPI shutdown detected (PM1a_CNT S5 sleep type written). + shutdown_requested: bool, +} + +impl DeviceManager { + /// Create all devices from a VmContext configuration. + /// + /// Returns the device manager plus MMIO slot info for the kernel cmdline. + pub fn from_context(ctx: &VmContext) -> Result { + // Serial console with capture buffer. + let console_buffer: ConsoleBuffer = Arc::new(Mutex::new(Vec::new())); + let serial = if let Some(ref path) = ctx.console_output { + let file = File::create(path).map_err(|e| { + WkrunError::Device(format!( + "failed to create console output '{}': {}", + path.display(), + e + )) + })?; + let tee = TeeWriter { + inner: Box::new(file), + buffer: console_buffer.clone(), + }; + Serial::new(COM1_BASE, Box::new(tee)) + } else { + let tee = TeeWriter { + inner: Box::new(std::io::stdout()), + buffer: console_buffer.clone(), + }; + Serial::new(COM1_BASE, Box::new(tee)) + }; + + // Virtio-blk (slot 0) — first disk (container rootfs). + let has_root_disk = !ctx.disks.is_empty(); + let virtio_blk = if let Some(disk) = ctx.disks.first() { + let backend = open_disk_backend(&disk.path, disk.format, disk.read_only)?; + let blk = VirtioBlock::new(backend, disk.read_only); + Some(VirtioMmioDevice::new(blk)) + } else { + None + }; + + // Virtio-blk2 (slot 4) — second disk (guest rootfs), if present. 
+ let virtio_blk2 = if let Some(disk) = ctx.disks.get(1) { + let backend = open_disk_backend(&disk.path, disk.format, disk.read_only)?; + let blk = VirtioBlock::new(backend, disk.read_only); + Some(VirtioMmioDevice::new(blk)) + } else { + None + }; + + // Virtio-vsock (slot 1) — always present. + let mut vsock_backend = VirtioVsock::new(GUEST_CID); + // Configure ports: listen=true creates Unix socket listener (host→guest), + // listen=false registers outbound target (guest→host). + for vp in &ctx.vsock_ports { + let socket_path = vp.host_path.to_string_lossy(); + if vp.listen { + let _ = vsock_backend.listen_on(vp.port, &socket_path); + } else { + vsock_backend.connect_to(vp.port, socket_path.to_string()); + } + } + let virtio_vsock = VirtioMmioDevice::new(vsock_backend); + + // Virtio-9p (slot 2) — optional, from fs_mounts. + let virtio_9p = ctx.fs_mounts.first().map(|mount| { + let p9 = Virtio9p::new(&mount.tag, mount.host_path.clone(), false); + VirtioMmioDevice::new(p9) + }); + + // Virtio-net (slot 3) — optional, from net_config. + let virtio_net = if let Some(ref net_cfg) = ctx.net_config { + let transport = Self::connect_net_transport(&net_cfg.socket_path)?; + let net = VirtioNet::new(net_cfg.mac, transport); + Some(VirtioMmioDevice::new(net)) + } else { + None + }; + + // Virtio-rng (slot 5) — always present. + let virtio_rng = VirtioMmioDevice::new(VirtioRng::new()); + + // Virtio-balloon (slot 6) — always present. + let virtio_balloon = VirtioMmioDevice::new(VirtioBalloon::new()); + + // Build MMIO slots for kernel cmdline. 
+ let mmio_slots = vec![ + MmioSlot { + index: 0, + active: virtio_blk.is_some(), + }, + MmioSlot { + index: 1, + active: true, + }, // vsock always active + MmioSlot { + index: 2, + active: virtio_9p.is_some(), + }, + MmioSlot { + index: 3, + active: virtio_net.is_some(), + }, + MmioSlot { + index: 4, + active: virtio_blk2.is_some(), + }, + MmioSlot { + index: 5, + active: true, + }, // rng always active + MmioSlot { + index: 6, + active: true, + }, // balloon always active + ]; + + let devices = DeviceManager { + serial, + irq_chip: IrqChip::new(ctx.num_vcpus), + pit: Pit::new(), + cmos_addr: 0, + virtio_blk, + virtio_vsock, + virtio_9p, + virtio_net, + virtio_blk2, + virtio_rng, + virtio_balloon, + blk_queue_notify_count: 0, + blk_completion_count: 0, + ioapic_mmio_count: 0, + lapic_mmio_count: 0, + window_requested: false, + last_tick: Instant::now(), + port61_toggle: false, + shutdown_requested: false, + }; + + Ok(DeviceSetup { + devices, + mmio_slots, + has_root_disk, + console_buffer, + }) + } + + /// Handle an I/O port output (write) from the guest. + /// + /// Returns `true` if skip_instruction should be called after. + pub fn handle_io_out(&mut self, port: u16, size: u8, data: u32) { + if self.serial.handles_port(port) { + self.serial.io_write(port, size, data); + if self.serial.has_interrupt() { + self.irq_chip.raise_irq(4); + } + } else if self.irq_chip.pic.handles_port(port) { + log::trace!("PIC write: port={:#X} data={:#X}", port, data as u8); + self.irq_chip.pic.write_port(port, data as u8); + } else if self.pit.handles_port(port) { + log::trace!("PIT write: port={:#X} data={:#X}", port, data as u8); + self.pit.write_port(port, data as u8); + } else if port == PM1A_CNT_BLK { + // ACPI PM1a control register: detect S5 shutdown. + // SLP_EN = bit 13, SLP_TYP = bits 12:10. 
+ let slp_en = (data >> 13) & 1; + let slp_typ = (data >> 10) & 0x7; + if slp_en == 1 && slp_typ == 5 { + log::info!("ACPI S5 shutdown detected (PM1a_CNT={:#X})", data); + self.shutdown_requested = true; + } + } else if port == 0x70 { + self.cmos_addr = (data as u8) & 0x7F; + } + // Ignore writes to other ports (PS/2, etc.). + } + + /// Handle an I/O port input (read) from the guest. + /// + /// Returns the data to inject into the guest register. + pub fn handle_io_in(&mut self, port: u16, size: u8) -> u32 { + if self.serial.handles_port(port) { + let val = self.serial.io_read(port, size); + if self.serial.has_interrupt() { + self.irq_chip.raise_irq(4); + } + val + } else if self.irq_chip.pic.handles_port(port) { + self.irq_chip.pic.read_port(port) as u32 + } else if self.pit.handles_port(port) { + self.pit.read_port(port) as u32 + } else if port == 0x71 { + cmos_read(self.cmos_addr) as u32 + } else if (PM1A_EVT_BLK..PM1A_EVT_BLK + 4).contains(&port) { + 0x00 // PM1a event: no events pending + } else if (PM1A_CNT_BLK..PM1A_CNT_BLK + 2).contains(&port) { + 0x00 // PM1a control: clear state + } else if (0xCF8..=0xCFF).contains(&port) { + 0xFFFF_FFFF // PCI config: no devices. + } else if port == 0x61 { + // System control port B: toggle bit 5 (PIT counter 2 output). + // + // Linux's `pit_calibrate_tsc()` reads this port in a tight loop + // waiting for bit 5 to change. A static value causes an infinite + // loop that stalls kernel boot. Toggling on each read lets the + // calibration complete. + self.port61_toggle = !self.port61_toggle; + if self.port61_toggle { + 0x20 + } else { + 0x00 + } + } else if port == 0x92 { + 0x02 // System control port A: A20 enabled. + } else if port == 0x60 || port == 0x64 { + // i8042 PS/2 controller: data (0x60) and status (0x64). + // + // Return 0x00 = both buffers empty, no pending data. + // Without this, the default 0xFF makes the i8042 driver spin in + // udelay() loops waiting for the input buffer to drain. 
+ 0x00 + } else { + 0xFF // Default: no device. + } + } + + /// Handle an MMIO read from the guest. + /// + /// Returns the data to inject into the destination register. + /// `vcpu_id` selects which LAPIC to read from (each vCPU has its own). + pub fn handle_mmio_read(&mut self, vcpu_id: u8, address: u64, size: u8) -> u64 { + // Check IOAPIC/LAPIC ranges first. + if let Some(val) = self.irq_chip.handle_mmio_read(vcpu_id, address, size) { + use super::super::memory::{ + IOAPIC_MMIO_BASE, IOAPIC_MMIO_SIZE, LAPIC_MMIO_BASE, LAPIC_MMIO_SIZE, + }; + if address >= IOAPIC_MMIO_BASE && address < IOAPIC_MMIO_BASE + IOAPIC_MMIO_SIZE { + self.ioapic_mmio_count += 1; + } else if address >= LAPIC_MMIO_BASE && address < LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE { + self.lapic_mmio_count += 1; + } + return val as u64; + } + + let blk_offset = address.wrapping_sub(mmio_base_for_slot(0)); + let vsock_offset = address.wrapping_sub(mmio_base_for_slot(1)); + let p9_offset = address.wrapping_sub(mmio_base_for_slot(2)); + let net_offset = address.wrapping_sub(mmio_base_for_slot(3)); + let blk2_offset = address.wrapping_sub(mmio_base_for_slot(4)); + let rng_offset = address.wrapping_sub(mmio_base_for_slot(5)); + let balloon_offset = address.wrapping_sub(mmio_base_for_slot(6)); + + if blk_offset < MMIO_SLOT_SIZE { + if let Some(ref dev) = self.virtio_blk { + dev.read(blk_offset, size) as u64 + } else { + 0 + } + } else if vsock_offset < MMIO_SLOT_SIZE { + self.virtio_vsock.read(vsock_offset, size) as u64 + } else if p9_offset < MMIO_SLOT_SIZE { + if let Some(ref dev) = self.virtio_9p { + dev.read(p9_offset, size) as u64 + } else { + 0 + } + } else if net_offset < MMIO_SLOT_SIZE { + if let Some(ref dev) = self.virtio_net { + dev.read(net_offset, size) as u64 + } else { + 0 + } + } else if blk2_offset < MMIO_SLOT_SIZE { + if let Some(ref dev) = self.virtio_blk2 { + dev.read(blk2_offset, size) as u64 + } else { + 0 + } + } else if rng_offset < MMIO_SLOT_SIZE { + self.virtio_rng.read(rng_offset, 
size) as u64 + } else if balloon_offset < MMIO_SLOT_SIZE { + self.virtio_balloon.read(balloon_offset, size) as u64 + } else { + 0 + } + } + + /// Handle an MMIO write from the guest. + /// + /// `vcpu_id` selects which LAPIC to write to (each vCPU has its own). + /// Returns the IPI action if the write was to the LAPIC ICR register. + pub fn handle_mmio_write( + &mut self, + vcpu_id: u8, + address: u64, + size: u8, + data: u64, + mem: &dyn GuestMemoryAccessor, + ) -> IpiAction { + // Check IOAPIC/LAPIC ranges first. + let result = self + .irq_chip + .handle_mmio_write(vcpu_id, address, size as u8, data as u32); + if result.handled { + use super::super::memory::{ + IOAPIC_MMIO_BASE, IOAPIC_MMIO_SIZE, LAPIC_MMIO_BASE, LAPIC_MMIO_SIZE, + }; + if address >= IOAPIC_MMIO_BASE && address < IOAPIC_MMIO_BASE + IOAPIC_MMIO_SIZE { + self.ioapic_mmio_count += 1; + } else if address >= LAPIC_MMIO_BASE && address < LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE { + self.lapic_mmio_count += 1; + } + return result.ipi_action; + } + + let blk_offset = address.wrapping_sub(mmio_base_for_slot(0)); + let vsock_offset = address.wrapping_sub(mmio_base_for_slot(1)); + let p9_offset = address.wrapping_sub(mmio_base_for_slot(2)); + let net_offset = address.wrapping_sub(mmio_base_for_slot(3)); + let blk2_offset = address.wrapping_sub(mmio_base_for_slot(4)); + let rng_offset = address.wrapping_sub(mmio_base_for_slot(5)); + let balloon_offset = address.wrapping_sub(mmio_base_for_slot(6)); + + if blk_offset < MMIO_SLOT_SIZE { + if blk_offset == 0x050 { + self.blk_queue_notify_count += 1; + } + if let Some(ref mut dev) = self.virtio_blk { + if dev.write(blk_offset, data as u32, size, mem) { + self.irq_chip.raise_irq(irq_for_slot(0)); + } + } + } else if vsock_offset < MMIO_SLOT_SIZE { + if self + .virtio_vsock + .write(vsock_offset, data as u32, size, mem) + { + self.irq_chip.raise_irq(irq_for_slot(1)); + } + } else if p9_offset < MMIO_SLOT_SIZE { + if let Some(ref mut dev) = self.virtio_9p { + if 
dev.write(p9_offset, data as u32, size, mem) { + self.irq_chip.raise_irq(irq_for_slot(2)); + } + } + } else if net_offset < MMIO_SLOT_SIZE { + if let Some(ref mut dev) = self.virtio_net { + if dev.write(net_offset, data as u32, size, mem) { + self.irq_chip.raise_irq(irq_for_slot(3)); + } + } + } else if blk2_offset < MMIO_SLOT_SIZE { + if blk2_offset == 0x050 { + self.blk_queue_notify_count += 1; + } + if let Some(ref mut dev) = self.virtio_blk2 { + if dev.write(blk2_offset, data as u32, size, mem) { + self.irq_chip.raise_irq(irq_for_slot(4)); + } + } + } else if rng_offset < MMIO_SLOT_SIZE { + if self.virtio_rng.write(rng_offset, data as u32, size, mem) { + self.irq_chip.raise_irq(irq_for_slot(5)); + } + } else if balloon_offset < MMIO_SLOT_SIZE { + if self + .virtio_balloon + .write(balloon_offset, data as u32, size, mem) + { + self.irq_chip.raise_irq(irq_for_slot(6)); + } + } + IpiAction::None + } + + /// Start async block I/O workers for virtio-blk devices (Plan B: WHPX-safe). + /// + /// Workers never access guest memory — all guest memory I/O happens + /// on the vCPU thread (in queue_notify and drain_completions). + pub fn start_blk_workers(&mut self) { + if let Some(ref mut dev) = self.virtio_blk { + dev.backend_mut().start_worker("blk-worker-0"); + } + if let Some(ref mut dev) = self.virtio_blk2 { + dev.backend_mut().start_worker("blk-worker-1"); + } + } + + /// Stop async block I/O workers. + /// + /// Called during shutdown. Also called by Drop if not explicitly called. + pub fn stop_blk_workers(&mut self) { + if let Some(ref mut dev) = self.virtio_blk { + dev.backend_mut().stop_worker(); + } + if let Some(ref mut dev) = self.virtio_blk2 { + dev.backend_mut().stop_worker(); + } + } + + /// Tick the PIT timer based on wall clock time and poll devices. + /// + /// Call this at the top of each vCPU run loop iteration. + /// `vcpu_id` selects which LAPIC timer to tick (BSP should be 0). 
+ pub fn tick_and_poll(&mut self, vcpu_id: u8, mem: &dyn GuestMemoryAccessor) { + // Tick PIT. + let now = Instant::now(); + let elapsed_ns = now.duration_since(self.last_tick).as_nanos() as u64; + self.last_tick = now; + + if elapsed_ns > 0 { + let fires = self.pit.tick(elapsed_ns); + for _ in 0..fires { + self.irq_chip.raise_irq(0); + } + } + + // LAPIC timers are now ticked per-vCPU in the runner loop (lock-free). + // This eliminates cross-vCPU contention on tick_and_poll(). + // Suppress unused variable — vcpu_id was the original single-vCPU target. + let _ = vcpu_id; + + // Drain async block I/O completions. + if let Some(ref mut dev) = self.virtio_blk { + if dev.poll_backend(mem) { + self.blk_completion_count += 1; + self.irq_chip.raise_irq(irq_for_slot(0)); + } + } + if let Some(ref mut dev) = self.virtio_blk2 { + if dev.poll_backend(mem) { + self.blk_completion_count += 1; + self.irq_chip.raise_irq(irq_for_slot(4)); + } + } + + // Poll vsock for host-initiated data. + if self.virtio_vsock.poll(mem) { + log::debug!("vsock poll raised IRQ {}", irq_for_slot(1)); + self.irq_chip.raise_irq(irq_for_slot(1)); + } + + // Poll net for incoming frames. + if let Some(ref mut dev) = self.virtio_net { + if dev.poll(mem) { + self.irq_chip.raise_irq(irq_for_slot(3)); + } + } + } + + /// Connect to the userspace networking proxy and return a transport. + /// + /// Connects via Unix stream socket on all platforms. 
+ fn connect_net_transport( + socket_path: &Path, + ) -> Result>> { + #[cfg(unix)] + { + let stream = std::os::unix::net::UnixStream::connect(socket_path).map_err(|e| { + WkrunError::Device(format!( + "failed to connect to net socket '{}': {}", + socket_path.display(), + e + )) + })?; + let transport = super::virtio::net::UnixStreamTransport::new(stream).map_err(|e| { + WkrunError::Device(format!("failed to configure net socket: {}", e)) + })?; + Ok(Some(Box::new(transport))) + } + #[cfg(windows)] + { + let stream = uds_windows::UnixStream::connect(socket_path).map_err(|e| { + WkrunError::Device(format!( + "failed to connect to net socket '{}': {}", + socket_path.display(), + e + )) + })?; + let transport = super::virtio::net::UdsTransport::new(stream).map_err(|e| { + WkrunError::Device(format!("failed to configure net socket: {}", e)) + })?; + Ok(Some(Box::new(transport))) + } + } + + /// Return block I/O diagnostic counters: (queue_notify_count, completion_count). + pub fn blk_stats(&self) -> (u64, u64) { + (self.blk_queue_notify_count, self.blk_completion_count) + } + + /// Get IOAPIC/LAPIC MMIO access counts for diagnostics. + pub fn apic_mmio_stats(&self) -> (u64, u64) { + (self.ioapic_mmio_count, self.lapic_mmio_count) + } + + /// Whether an ACPI S5 shutdown was detected. + pub fn shutdown_requested(&self) -> bool { + self.shutdown_requested + } + + /// Whether an interrupt window has been requested. + pub fn window_requested(&self) -> bool { + self.window_requested + } + + /// Set the interrupt window requested flag. + pub fn set_window_requested(&mut self, requested: bool) { + self.window_requested = requested; + } + + /// Get per-vCPU LAPIC references for the runner fast path. + /// + /// Each ref can be locked independently of the DeviceManager lock, + /// eliminating cross-vCPU contention on LAPIC MMIO reads. 
/// Create a `DeviceManager` from explicit components (for testing).
///
/// The manager wraps `serial` and gets an `IrqChip` built with `IrqChip::new(1)`,
/// a fresh PIT, and the always-present vsock (guest CID `GUEST_CID`), rng
/// and balloon devices. No block, 9p or net device is attached, and every
/// counter and flag starts cleared.
///
/// NOTE(review): this is `pub` and not gated on `cfg(test)`, so it is also
/// reachable from non-test code — confirm that is intended.
pub fn device_manager_with_serial(serial: Serial) -> DeviceManager {
    let vsock_backend = VirtioVsock::new(GUEST_CID);
    DeviceManager {
        serial,
        irq_chip: IrqChip::new(1),
        pit: Pit::new(),
        cmos_addr: 0,
        // Optional devices are left out.
        virtio_blk: None,
        virtio_vsock: VirtioMmioDevice::new(vsock_backend),
        virtio_9p: None,
        virtio_net: None,
        virtio_blk2: None,
        virtio_rng: VirtioMmioDevice::new(VirtioRng::new()),
        virtio_balloon: VirtioMmioDevice::new(VirtioBalloon::new()),
        // Diagnostic counters start at zero.
        blk_queue_notify_count: 0,
        blk_completion_count: 0,
        ioapic_mmio_count: 0,
        lapic_mmio_count: 0,
        window_requested: false,
        last_tick: Instant::now(),
        port61_toggle: false,
        shutdown_requested: false,
    }
}
/// In-memory sink recording every byte written to it.
///
/// Clones share the same underlying buffer, so a test can hand one clone to
/// the device under test and inspect the output through another.
#[derive(Clone)]
struct CaptureSink {
    buf: Arc<Mutex<Vec<u8>>>,
}

impl CaptureSink {
    /// Create a sink with an empty buffer.
    fn new() -> Self {
        Self {
            buf: Arc::default(),
        }
    }

    /// Copy out everything captured so far.
    fn contents(&self) -> Vec<u8> {
        let captured = self.buf.lock().unwrap();
        captured.to_vec()
    }
}

impl Write for CaptureSink {
    /// Append all of `buf` to the shared buffer; never writes partially.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut captured = self.buf.lock().unwrap();
        captured.extend(buf);
        Ok(buf.len())
    }

    /// Nothing is buffered outside the shared vector, so this is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
#[test]
fn test_to_bcd() {
    // (decimal input, expected packed-BCD byte)
    let cases: [(u8, u8); 6] = [
        (0, 0x00),
        (9, 0x09),
        (10, 0x10),
        (26, 0x26),
        (59, 0x59),
        (99, 0x99),
    ];
    for (decimal, bcd) in cases {
        assert_eq!(to_bcd(decimal), bcd);
    }
}
+ assert!(year >= 0x25, "year BCD: {:#04x}", year); + } + + #[test] + fn test_cmos_read_battery_ok() { + let mut dm = make_test_devices(); + dm.handle_io_out(0x70, 1, 0x0D); + let status_d = dm.handle_io_in(0x71, 1); + assert_eq!(status_d, 0x80); + } + + #[test] + fn test_mmio_read_no_blk_device() { + let dm = make_test_devices(); + // Read from virtio-blk slot when no device present. + let data = dm.handle_mmio_read(0, mmio_base_for_slot(0), 4); + assert_eq!(data, 0); + } + + #[test] + fn test_mmio_read_vsock_magic() { + let dm = make_test_devices(); + // Read virtio magic from vsock MMIO slot. + let magic = dm.handle_mmio_read(0, mmio_base_for_slot(1), 4); + assert_eq!(magic, 0x7472_6976); // "virt" in LE. + } + + #[test] + fn test_mmio_read_vsock_device_id() { + let dm = make_test_devices(); + // Device ID is at offset 0x008. + let device_id = dm.handle_mmio_read(0, mmio_base_for_slot(1) + 0x008, 4); + assert_eq!(device_id, 19); // vsock device ID. + } + + #[test] + fn test_mmio_read_out_of_range() { + let dm = make_test_devices(); + // Read from an address that doesn't belong to any device. 
+ let data = dm.handle_mmio_read(0, 0xE000_0000, 4); + assert_eq!(data, 0); + } + + #[test] + fn test_window_requested_default() { + let dm = make_test_devices(); + assert!(!dm.window_requested()); + } + + #[test] + fn test_window_requested_toggle() { + let mut dm = make_test_devices(); + dm.set_window_requested(true); + assert!(dm.window_requested()); + dm.set_window_requested(false); + assert!(!dm.window_requested()); + } + + #[test] + fn test_tee_writer() { + let inner_buf = Arc::new(Mutex::new(Vec::new())); + let capture_buf: ConsoleBuffer = Arc::new(Mutex::new(Vec::new())); + + let inner = CaptureSink { + buf: inner_buf.clone(), + }; + let mut tee = super::TeeWriter { + inner: Box::new(inner), + buffer: capture_buf.clone(), + }; + + tee.write_all(b"Hello").unwrap(); + tee.write_all(b", VM!").unwrap(); + tee.flush().unwrap(); + + // Both sinks should have the same content. + assert_eq!(inner_buf.lock().unwrap().as_slice(), b"Hello, VM!"); + assert_eq!(capture_buf.lock().unwrap().as_slice(), b"Hello, VM!"); + } + + #[test] + fn test_console_buffer_store_and_get() { + let buf: ConsoleBuffer = Arc::new(Mutex::new(Vec::new())); + buf.lock().unwrap().extend_from_slice(b"test output"); + + let ctx_id = 90000; // Unique ID to avoid conflicts. + super::store_console_buffer(ctx_id, buf); + + let output = super::get_console_output(ctx_id).unwrap(); + assert_eq!(output, b"test output"); + + // Cleanup. + super::remove_console_buffer(ctx_id); + assert!(super::get_console_output(ctx_id).is_none()); + } + + #[test] + fn test_console_buffer_not_found() { + assert!(super::get_console_output(89999).is_none()); + } + + #[test] + fn test_from_context_has_console_buffer() { + let ctx = VmContext::default_for_test(); + let setup = DeviceManager::from_context(&ctx).unwrap(); + // Buffer should be empty initially. 
+ assert!(setup.console_buffer.lock().unwrap().is_empty()); + } + + #[test] + fn test_from_context_minimal() { + let ctx = VmContext::default_for_test(); + let setup = DeviceManager::from_context(&ctx).unwrap(); + assert!(!setup.has_root_disk); + // Slot 0 (blk) inactive, slot 1 (vsock) active, slot 2 (9p) inactive. + assert!(!setup.mmio_slots[0].active); + assert!(setup.mmio_slots[1].active); + assert!(!setup.mmio_slots[2].active); + } + + #[test] + fn test_from_context_with_fs_mount() { + let mut ctx = VmContext::default_for_test(); + ctx.fs_mounts.push(super::super::super::context::FsMount { + tag: "test".to_string(), + host_path: PathBuf::from("/tmp"), + }); + let setup = DeviceManager::from_context(&ctx).unwrap(); + // Slot 2 (9p) should now be active. + assert!(setup.mmio_slots[2].active); + } + + #[test] + fn test_from_context_net_slot_inactive_by_default() { + let ctx = VmContext::default_for_test(); + let setup = DeviceManager::from_context(&ctx).unwrap(); + // Slot 3 (net) should be inactive when no net_config. + assert!(!setup.mmio_slots[3].active); + } + + #[test] + fn test_mmio_read_no_net_device() { + let dm = make_test_devices(); + // Read from virtio-net slot when no device present. + let data = dm.handle_mmio_read(0, mmio_base_for_slot(3), 4); + assert_eq!(data, 0); + } + + #[test] + fn test_acpi_shutdown_not_requested_initially() { + let dm = make_test_devices(); + assert!(!dm.shutdown_requested()); + } + + #[test] + fn test_acpi_s5_shutdown_detected() { + let mut dm = make_test_devices(); + // Write SLP_TYP=5, SLP_EN=1 → bits 12:10 = 0b101, bit 13 = 1. + // Value = (1 << 13) | (5 << 10) = 0x2000 | 0x1400 = 0x3400. + dm.handle_io_out(0x604, 2, 0x3400); + assert!(dm.shutdown_requested()); + } + + #[test] + fn test_acpi_non_s5_write_ignored() { + let mut dm = make_test_devices(); + // SLP_EN=1, SLP_TYP=0 → not S5. 
+ dm.handle_io_out(0x604, 2, 0x2000); + assert!(!dm.shutdown_requested()); + } + + #[test] + fn test_acpi_pm1a_evt_read_zero() { + let mut dm = make_test_devices(); + assert_eq!(dm.handle_io_in(0x600, 4), 0x00); + } + + #[test] + fn test_acpi_pm1a_cnt_read_zero() { + let mut dm = make_test_devices(); + assert_eq!(dm.handle_io_in(0x604, 2), 0x00); + } + + // --- virtio-rng (slot 5) --- + + #[test] + fn test_mmio_read_rng_magic() { + let dm = make_test_devices(); + let magic = dm.handle_mmio_read(0, mmio_base_for_slot(5), 4); + assert_eq!(magic, 0x7472_6976); // "virt" in LE. + } + + #[test] + fn test_mmio_read_rng_device_id() { + let dm = make_test_devices(); + let device_id = dm.handle_mmio_read(0, mmio_base_for_slot(5) + 0x008, 4); + assert_eq!(device_id, 4); // VIRTIO_ID_RNG + } + + // --- virtio-balloon (slot 6) --- + + #[test] + fn test_mmio_read_balloon_magic() { + let dm = make_test_devices(); + let magic = dm.handle_mmio_read(0, mmio_base_for_slot(6), 4); + assert_eq!(magic, 0x7472_6976); // "virt" in LE. + } + + #[test] + fn test_mmio_read_balloon_device_id() { + let dm = make_test_devices(); + let device_id = dm.handle_mmio_read(0, mmio_base_for_slot(6) + 0x008, 4); + assert_eq!(device_id, 5); // VIRTIO_ID_BALLOON + } + + #[test] + fn test_from_context_rng_and_balloon_always_active() { + let ctx = VmContext::default_for_test(); + let setup = DeviceManager::from_context(&ctx).unwrap(); + // Slots 5 (rng) and 6 (balloon) should always be active. + assert!(setup.mmio_slots[5].active, "rng slot should be active"); + assert!(setup.mmio_slots[6].active, "balloon slot should be active"); + } +} diff --git a/src/vmm/src/windows/devices/mod.rs b/src/vmm/src/windows/devices/mod.rs new file mode 100644 index 000000000..f54460b49 --- /dev/null +++ b/src/vmm/src/windows/devices/mod.rs @@ -0,0 +1,10 @@ +//! Device emulation for the guest VM. 
+ +pub mod ioapic; +pub mod irq_chip; +pub mod lapic; +pub mod manager; +pub mod pic; +pub mod pit; +pub mod serial; +pub mod virtio; diff --git a/src/vmm/src/windows/devices/pic.rs b/src/vmm/src/windows/devices/pic.rs new file mode 100644 index 000000000..8838848d1 --- /dev/null +++ b/src/vmm/src/windows/devices/pic.rs @@ -0,0 +1,809 @@ +//! 8259 PIC (Programmable Interrupt Controller) emulation. +//! +//! Emulates a dual 8259 PIC (master + slave) for legacy interrupt routing. +//! +//! Master PIC: I/O ports 0x20-0x21, handles IRQs 0-7 +//! Slave PIC: I/O ports 0xA0-0xA1, handles IRQs 8-15 +//! Slave is connected to master IRQ 2 (cascade). +//! +//! The Linux kernel in PIC mode (noapic nolapic) programs the PICs to: +//! - Master: vector base 0x20 (IRQs 0-7 → vectors 0x20-0x27) +//! - Slave: vector base 0x28 (IRQs 8-15 → vectors 0x28-0x2F) + +use super::super::vcpu::IoHandler; + +/// Master PIC command port. +pub const PIC_MASTER_CMD: u16 = 0x20; +/// Master PIC data port. +pub const PIC_MASTER_DATA: u16 = 0x21; +/// Slave PIC command port. +pub const PIC_SLAVE_CMD: u16 = 0xA0; +/// Slave PIC data port. +pub const PIC_SLAVE_DATA: u16 = 0xA1; + +/// Cascade IRQ (slave connected to master IRQ 2). +const CASCADE_IRQ: u8 = 2; + +/// State for a single 8259 PIC chip. +#[derive(Debug)] +struct PicChip { + /// Interrupt Request Register — pending interrupt requests. + irr: u8, + /// In-Service Register — interrupts currently being serviced. + isr: u8, + /// Interrupt Mask Register — masked (disabled) interrupts. + imr: u8, + /// Vector base (aligned to 8, set by ICW2). + vector_base: u8, + /// ICW initialization state machine. + init_state: InitState, + /// Whether ICW4 is needed (from ICW1 bit 0). + icw4_needed: bool, + /// Whether to read ISR (true) or IRR (false) on command port read. + read_isr: bool, + /// Auto-EOI mode (from ICW4 bit 1). + auto_eoi: bool, +} + +/// ICW initialization state machine. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InitState { + /// Normal operation (not in initialization). + Ready, + /// Waiting for ICW2 (vector base). + WaitIcw2, + /// Waiting for ICW3 (cascade configuration). + WaitIcw3, + /// Waiting for ICW4 (mode). + WaitIcw4, +} + +impl PicChip { + fn new() -> Self { + PicChip { + irr: 0, + isr: 0, + imr: 0xFF, // All interrupts masked initially + vector_base: 0, + init_state: InitState::Ready, + icw4_needed: false, + read_isr: false, + auto_eoi: false, + } + } + + /// Write to the command port (port 0x20 or 0xA0). + fn write_command(&mut self, data: u8) { + if data & 0x10 != 0 { + // ICW1: bit 4 = 1 → start initialization sequence. + self.icw4_needed = data & 0x01 != 0; + self.init_state = InitState::WaitIcw2; + // Reset during init. + self.isr = 0; + self.irr = 0; + self.imr = 0; + self.auto_eoi = false; + self.read_isr = false; + } else if data & 0x08 != 0 { + // OCW3: bit 3 = 1. + if data & 0x02 != 0 { + // Read register command. + self.read_isr = data & 0x01 != 0; + } + } else { + // OCW2: End of Interrupt. + let is_eoi = data & 0x20 != 0; + let is_specific = data & 0x40 != 0; + if is_eoi { + if is_specific { + // Specific EOI: clear specific ISR bit. + let irq = data & 0x07; + self.isr &= !(1 << irq); + } else { + // Non-specific EOI: clear highest-priority ISR bit. + for i in 0..8u8 { + if self.isr & (1 << i) != 0 { + self.isr &= !(1 << i); + break; + } + } + } + } + } + } + + /// Write to the data port (port 0x21 or 0xA1). + fn write_data(&mut self, data: u8) { + match self.init_state { + InitState::WaitIcw2 => { + // ICW2: vector base (upper 5 bits used, lower 3 are IRQ number). + self.vector_base = data & 0xF8; + self.init_state = InitState::WaitIcw3; + } + InitState::WaitIcw3 => { + // ICW3: cascade configuration (we accept but don't use the value). 
+ if self.icw4_needed { + self.init_state = InitState::WaitIcw4; + } else { + self.init_state = InitState::Ready; + } + } + InitState::WaitIcw4 => { + // ICW4: mode configuration. + self.auto_eoi = data & 0x02 != 0; + self.init_state = InitState::Ready; + } + InitState::Ready => { + // OCW1: set interrupt mask register. + self.imr = data; + } + } + } + + /// Read from the command port. + fn read_command(&self) -> u8 { + if self.read_isr { + self.isr + } else { + self.irr + } + } + + /// Read from the data port. + fn read_data(&self) -> u8 { + self.imr + } + + /// Raise an interrupt request on this chip (local IRQ 0-7). + fn raise_irq(&mut self, irq: u8) { + self.irr |= 1 << (irq & 7); + } + + /// Clear an interrupt request (edge-triggered reset). + fn clear_irq(&mut self, irq: u8) { + self.irr &= !(1 << (irq & 7)); + } + + /// Get the highest-priority pending (unmasked, deliverable) IRQ, if any. + /// + /// Implements proper 8259A priority masking: when an interrupt is + /// in-service, all equal-or-lower-priority interrupts are blocked. + /// IRQ 0 has highest priority, IRQ 7 has lowest (default fixed + /// priority mode). + fn pending_irq(&self) -> Option { + let requested = self.irr & !self.imr; + if requested == 0 { + return None; + } + // Find the highest-priority (lowest-numbered) in-service IRQ. + // All IRQs at that level or lower are blocked. + let priority_ceiling = (0..8u8).find(|&i| self.isr & (1 << i) != 0); + match priority_ceiling { + Some(ceil) => { + // Only IRQs with higher priority (lower number) than the + // in-service IRQ can be delivered. + (0..ceil).find(|&i| requested & (1 << i) != 0) + } + None => { + // No interrupt in-service — deliver highest priority pending. + (0..8u8).find(|&i| requested & (1 << i) != 0) + } + } + } + + /// Acknowledge the highest-priority pending interrupt. + /// Moves the IRQ from IRR to ISR and returns the vector. 
+ fn acknowledge(&mut self) -> Option { + if let Some(irq) = self.pending_irq() { + self.irr &= !(1 << irq); + if self.auto_eoi { + // Auto-EOI: don't set ISR. + } else { + self.isr |= 1 << irq; + } + Some(self.vector_base + irq) + } else { + None + } + } +} + +/// Dual 8259 PIC (master + slave). +pub struct Pic { + master: PicChip, + slave: PicChip, +} + +impl Default for Pic { + fn default() -> Self { + Self::new() + } +} + +impl Pic { + /// Create a new dual PIC with default state (all masked). + pub fn new() -> Self { + Pic { + master: PicChip::new(), + slave: PicChip::new(), + } + } + + /// Raise an interrupt request (IRQ 0-15). + /// + /// IRQs 0-7 go to the master PIC, IRQs 8-15 go to the slave PIC. + /// When a slave IRQ is raised, the cascade line (master IRQ 2) is also raised. + pub fn raise_irq(&mut self, irq: u8) { + if irq < 8 { + self.master.raise_irq(irq); + } else { + self.slave.raise_irq(irq - 8); + // Slave cascades through master IRQ 2. + self.master.raise_irq(CASCADE_IRQ); + } + } + + /// Clear an interrupt request (for edge-triggered mode). + pub fn clear_irq(&mut self, irq: u8) { + if irq < 8 { + self.master.clear_irq(irq); + } else { + self.slave.clear_irq(irq - 8); + // If no more slave IRQs pending, clear cascade on master. + if self.slave.pending_irq().is_none() { + self.master.clear_irq(CASCADE_IRQ); + } + } + } + + /// Check if there are any pending (unmasked, deliverable) interrupts. + pub fn has_pending(&self) -> bool { + self.master.pending_irq().is_some() + } + + /// Acknowledge the highest-priority pending interrupt. + /// + /// Returns the interrupt vector to deliver to the CPU, or None if + /// no interrupts are pending. + pub fn acknowledge(&mut self) -> Option { + if let Some(master_irq) = self.master.pending_irq() { + if master_irq == CASCADE_IRQ { + // Cascade: try to acknowledge slave first. + let vector = self.slave.acknowledge(); + if vector.is_some() { + // Slave had a real IRQ — acknowledge cascade on master. 
+ self.master.acknowledge(); + // If no more slave IRQs pending, clear cascade IRR. + if self.slave.pending_irq().is_none() { + self.master.clear_irq(CASCADE_IRQ); + } + } else { + // Spurious cascade: slave has no deliverable IRQ. + // Clear cascade IRR without setting ISR — otherwise + // ISR bit 2 would be permanently stuck (guest never + // sends EOI for an interrupt it didn't receive). + self.master.clear_irq(CASCADE_IRQ); + } + vector + } else { + self.master.acknowledge() + } + } else { + None + } + } + + /// Get master PIC state for diagnostics: (IRR, ISR, IMR, vector_base). + pub fn master_state(&self) -> (u8, u8, u8, u8) { + ( + self.master.irr, + self.master.isr, + self.master.imr, + self.master.vector_base, + ) + } + + /// Get slave PIC state for diagnostics: (IRR, ISR, IMR, vector_base). + pub fn slave_state(&self) -> (u8, u8, u8, u8) { + ( + self.slave.irr, + self.slave.isr, + self.slave.imr, + self.slave.vector_base, + ) + } + + /// Check if the given I/O port belongs to either PIC. + pub fn handles_port(&self, port: u16) -> bool { + matches!( + port, + PIC_MASTER_CMD | PIC_MASTER_DATA | PIC_SLAVE_CMD | PIC_SLAVE_DATA + ) + } +} + +impl IoHandler for Pic { + fn io_read(&self, port: u16, _size: u8) -> u32 { + let val = match port { + PIC_MASTER_CMD => self.master.read_command(), + PIC_MASTER_DATA => self.master.read_data(), + PIC_SLAVE_CMD => self.slave.read_command(), + PIC_SLAVE_DATA => self.slave.read_data(), + _ => 0xFF, + }; + val as u32 + } + + fn io_write(&self, port: u16, _size: u8, data: u32) { + // IoHandler takes &self, but we need &mut self for PIC state. + // This is a design limitation — for now, the boot_kernel example + // uses Pic directly with &mut self methods. This trait impl is + // provided for interface compatibility but should not be used + // when mutation is needed. + // + // In practice, the vCPU loop will call write_port() directly. 
+ let _ = (port, data); + } +} + +impl Pic { + /// Write to a PIC I/O port (mutable version for the vCPU loop). + pub fn write_port(&mut self, port: u16, data: u8) { + match port { + PIC_MASTER_CMD => self.master.write_command(data), + PIC_MASTER_DATA => self.master.write_data(data), + PIC_SLAVE_CMD => self.slave.write_command(data), + PIC_SLAVE_DATA => self.slave.write_data(data), + _ => {} + } + } + + /// Read from a PIC I/O port. + pub fn read_port(&self, port: u16) -> u8 { + match port { + PIC_MASTER_CMD => self.master.read_command(), + PIC_MASTER_DATA => self.master.read_data(), + PIC_SLAVE_CMD => self.slave.read_command(), + PIC_SLAVE_DATA => self.slave.read_data(), + _ => 0xFF, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ---- PicChip unit tests ---- + + #[test] + fn test_pic_chip_initial_state() { + let chip = PicChip::new(); + assert_eq!(chip.irr, 0); + assert_eq!(chip.isr, 0); + assert_eq!(chip.imr, 0xFF, "all IRQs masked initially"); + assert_eq!(chip.vector_base, 0); + assert_eq!(chip.init_state, InitState::Ready); + } + + #[test] + fn test_pic_chip_raise_irq_while_masked() { + let mut chip = PicChip::new(); + chip.raise_irq(0); + assert_eq!(chip.irr, 0x01); + // All masked, so no pending. + assert_eq!(chip.pending_irq(), None); + } + + #[test] + fn test_pic_chip_raise_and_unmask() { + let mut chip = PicChip::new(); + chip.imr = 0; // Unmask all. + chip.vector_base = 0x20; + chip.raise_irq(0); + assert_eq!(chip.pending_irq(), Some(0)); + + let vector = chip.acknowledge(); + assert_eq!(vector, Some(0x20)); + assert_eq!(chip.irr, 0, "IRR cleared after acknowledge"); + assert_eq!(chip.isr, 0x01, "ISR set after acknowledge"); + } + + #[test] + fn test_pic_chip_priority_order() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x20; + + // Raise IRQ 3 and IRQ 1 — IRQ 1 has higher priority. 
+ chip.raise_irq(3); + chip.raise_irq(1); + assert_eq!(chip.pending_irq(), Some(1)); + + let vector = chip.acknowledge(); + assert_eq!(vector, Some(0x21)); // 0x20 + 1 + + // IRQ 3 is blocked while IRQ 1 is in-service (lower priority). + assert_eq!( + chip.pending_irq(), + None, + "IRQ 3 blocked while IRQ 1 in-service" + ); + + // After EOI for IRQ 1, IRQ 3 becomes deliverable. + chip.write_command(0x61); // Specific EOI for IRQ 1. + assert_eq!(chip.pending_irq(), Some(3)); + } + + #[test] + fn test_pic_chip_isr_blocks_lower_priority() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x20; + + chip.raise_irq(0); + chip.acknowledge(); // IRQ 0 now in ISR. + + // Raise IRQ 1 — lower priority than IRQ 0. + // With proper 8259A priority masking, IRQ 1 is blocked while + // IRQ 0 is in-service (all equal-or-lower priority blocked). + chip.raise_irq(1); + assert_eq!( + chip.pending_irq(), + None, + "IRQ 1 must be blocked while IRQ 0 is in-service" + ); + + // After EOI for IRQ 0, IRQ 1 becomes deliverable. + chip.write_command(0x60); // Specific EOI for IRQ 0. + assert_eq!(chip.isr, 0, "ISR cleared after specific EOI"); + assert_eq!(chip.pending_irq(), Some(1), "IRQ 1 deliverable after EOI"); + } + + #[test] + fn test_pic_chip_nonspecific_eoi() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x20; + + chip.raise_irq(0); + chip.acknowledge(); // IRQ 0 in ISR. + assert_eq!(chip.isr, 0x01); + + // Non-specific EOI (OCW2 with bit 5 set). + chip.write_command(0x20); + assert_eq!(chip.isr, 0, "ISR cleared by EOI"); + } + + #[test] + fn test_pic_chip_specific_eoi() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x20; + + // Acknowledge IRQ 0, then EOI it, then acknowledge IRQ 2. + chip.raise_irq(0); + chip.raise_irq(2); + chip.acknowledge(); // IRQ 0 acknowledged → ISR bit 0. + assert_eq!(chip.isr, 0x01); + + // IRQ 2 is blocked while IRQ 0 in-service (priority masking). 
+ assert_eq!(chip.pending_irq(), None); + + // EOI IRQ 0, then IRQ 2 becomes deliverable. + chip.write_command(0x60); // Specific EOI for IRQ 0. + assert_eq!(chip.isr, 0x00); + chip.acknowledge(); // IRQ 2 acknowledged → ISR bit 2. + assert_eq!(chip.isr, 0x04); + + // Specific EOI for IRQ 2 (OCW2: 0x60 | 2 = 0x62). + chip.write_command(0x62); + assert_eq!(chip.isr, 0x00, "ISR should be clear after both EOIs"); + } + + #[test] + fn test_pic_chip_icw_sequence() { + let mut chip = PicChip::new(); + + // ICW1: start init, ICW4 needed. + chip.write_command(0x11); + assert_eq!(chip.init_state, InitState::WaitIcw2); + assert!(chip.icw4_needed); + + // ICW2: vector base = 0x20. + chip.write_data(0x20); + assert_eq!(chip.vector_base, 0x20); + assert_eq!(chip.init_state, InitState::WaitIcw3); + + // ICW3: cascade config. + chip.write_data(0x04); // Master: slave on IRQ 2. + assert_eq!(chip.init_state, InitState::WaitIcw4); + + // ICW4: 8086 mode. + chip.write_data(0x01); + assert_eq!(chip.init_state, InitState::Ready); + } + + #[test] + fn test_pic_chip_icw_without_icw4() { + let mut chip = PicChip::new(); + + // ICW1 without ICW4. + chip.write_command(0x10); + assert!(!chip.icw4_needed); + + // ICW2. + chip.write_data(0x28); + assert_eq!(chip.vector_base, 0x28); + + // ICW3 → goes straight to Ready. + chip.write_data(0x02); + assert_eq!(chip.init_state, InitState::Ready); + } + + #[test] + fn test_pic_chip_imr_read_write() { + let mut chip = PicChip::new(); + + // After init, writing data port sets IMR. + chip.write_data(0xFB); // Mask all except IRQ 2. + assert_eq!(chip.imr, 0xFB); + assert_eq!(chip.read_data(), 0xFB); + } + + #[test] + fn test_pic_chip_read_irr_isr() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x20; + + chip.raise_irq(3); + + // Default read = IRR. + assert_eq!(chip.read_command(), 0x08); // bit 3. + + // OCW3: read ISR. + chip.write_command(0x0B); + assert_eq!(chip.read_command(), 0); // No ISR yet. 
+ + chip.acknowledge(); // IRQ 3 → ISR. + assert_eq!(chip.read_command(), 0x08); // ISR bit 3. + + // OCW3: read IRR. + chip.write_command(0x0A); + assert_eq!(chip.read_command(), 0); // IRR cleared. + } + + #[test] + fn test_pic_chip_auto_eoi() { + let mut chip = PicChip::new(); + + // Init with auto-EOI. + chip.write_command(0x11); // ICW1. + chip.write_data(0x20); // ICW2. + chip.write_data(0x00); // ICW3. + chip.write_data(0x03); // ICW4: 8086 mode + auto-EOI. + assert!(chip.auto_eoi); + + chip.imr = 0; + chip.raise_irq(0); + let vector = chip.acknowledge(); + assert_eq!(vector, Some(0x20)); + assert_eq!(chip.isr, 0, "ISR should not be set in auto-EOI mode"); + } + + /// Validates the fix for the WHPX flakiness root cause: PIT (IRQ 0) + /// in-service must block vsock (IRQ 6) delivery. Without this fix, + /// both interrupts end up in ISR simultaneously (ISR=0x41), causing + /// a deadlock where the kernel can't service either handler. + #[test] + fn test_pic_chip_pit_blocks_vsock_priority() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x30; // Linux programs master PIC to base 0x30. + + // PIT fires (IRQ 0) and gets acknowledged. + chip.raise_irq(0); + assert_eq!(chip.acknowledge(), Some(0x30)); + assert_eq!(chip.isr, 0x01); // PIT in-service. + + // While PIT handler runs, vsock (IRQ 6) fires. + chip.raise_irq(6); + + // IRQ 6 must NOT be deliverable (lower priority than IRQ 0). + assert_eq!( + chip.pending_irq(), + None, + "vsock IRQ 6 must be blocked while PIT IRQ 0 is in-service" + ); + + // Kernel sends specific EOI for PIT (0x60 | 0 = 0x60). + chip.write_command(0x60); + assert_eq!(chip.isr, 0x00); + + // Now vsock IRQ 6 is deliverable. + assert_eq!(chip.pending_irq(), Some(6)); + assert_eq!(chip.acknowledge(), Some(0x36)); + assert_eq!(chip.isr, 0x40); // Only vsock in-service, NOT 0x41. 
+ } + + #[test] + fn test_pic_chip_higher_priority_preempts() { + let mut chip = PicChip::new(); + chip.imr = 0; + chip.vector_base = 0x30; + + // IRQ 6 (vsock) in-service. + chip.raise_irq(6); + chip.acknowledge(); + assert_eq!(chip.isr, 0x40); + + // IRQ 0 (PIT) fires — higher priority, should preempt. + chip.raise_irq(0); + assert_eq!( + chip.pending_irq(), + Some(0), + "higher-priority IRQ 0 should preempt IRQ 6" + ); + } + + #[test] + fn test_pic_chip_clear_irq() { + let mut chip = PicChip::new(); + chip.raise_irq(5); + assert_eq!(chip.irr, 0x20); + chip.clear_irq(5); + assert_eq!(chip.irr, 0); + } + + // ---- Dual Pic tests ---- + + #[test] + fn test_pic_new_no_pending() { + let pic = Pic::new(); + assert!(!pic.has_pending()); + } + + #[test] + fn test_pic_master_irq_lifecycle() { + let mut pic = Pic::new(); + + // Program master PIC: vector base 0x20, unmask IRQ 0. + pic.write_port(PIC_MASTER_CMD, 0x11); + pic.write_port(PIC_MASTER_DATA, 0x20); + pic.write_port(PIC_MASTER_DATA, 0x04); + pic.write_port(PIC_MASTER_DATA, 0x01); + pic.write_port(PIC_MASTER_DATA, 0xFE); // Unmask only IRQ 0. + + pic.raise_irq(0); + assert!(pic.has_pending()); + + let vector = pic.acknowledge(); + assert_eq!(vector, Some(0x20)); + assert!(!pic.has_pending()); + + // EOI. + pic.write_port(PIC_MASTER_CMD, 0x20); + } + + #[test] + fn test_pic_slave_irq_lifecycle() { + let mut pic = Pic::new(); + + // Program master: vector 0x20, unmask IRQ 2 (cascade). + pic.write_port(PIC_MASTER_CMD, 0x11); + pic.write_port(PIC_MASTER_DATA, 0x20); + pic.write_port(PIC_MASTER_DATA, 0x04); + pic.write_port(PIC_MASTER_DATA, 0x01); + pic.write_port(PIC_MASTER_DATA, 0xFB); // Unmask only IRQ 2. + + // Program slave: vector 0x28, unmask IRQ 0 (= global IRQ 8). + pic.write_port(PIC_SLAVE_CMD, 0x11); + pic.write_port(PIC_SLAVE_DATA, 0x28); + pic.write_port(PIC_SLAVE_DATA, 0x02); + pic.write_port(PIC_SLAVE_DATA, 0x01); + pic.write_port(PIC_SLAVE_DATA, 0xFE); // Unmask only slave IRQ 0. 
+ + // Raise IRQ 8 (slave IRQ 0). + pic.raise_irq(8); + assert!(pic.has_pending()); + + let vector = pic.acknowledge(); + assert_eq!(vector, Some(0x28)); // Slave vector base + 0. + assert!(!pic.has_pending()); + + // EOI to both slave and master. + pic.write_port(PIC_SLAVE_CMD, 0x20); + pic.write_port(PIC_MASTER_CMD, 0x20); + } + + #[test] + fn test_pic_handles_port() { + let pic = Pic::new(); + assert!(pic.handles_port(PIC_MASTER_CMD)); + assert!(pic.handles_port(PIC_MASTER_DATA)); + assert!(pic.handles_port(PIC_SLAVE_CMD)); + assert!(pic.handles_port(PIC_SLAVE_DATA)); + assert!(!pic.handles_port(0x22)); + assert!(!pic.handles_port(0x3F8)); + } + + #[test] + fn test_pic_read_port() { + let mut pic = Pic::new(); + + // Init master. + pic.write_port(PIC_MASTER_CMD, 0x11); + pic.write_port(PIC_MASTER_DATA, 0x20); + pic.write_port(PIC_MASTER_DATA, 0x04); + pic.write_port(PIC_MASTER_DATA, 0x01); + + // Set IMR. + pic.write_port(PIC_MASTER_DATA, 0xAB); + assert_eq!(pic.read_port(PIC_MASTER_DATA), 0xAB); + } + + #[test] + fn test_pic_multiple_master_irqs() { + let mut pic = Pic::new(); + + // Init and unmask all. + pic.write_port(PIC_MASTER_CMD, 0x11); + pic.write_port(PIC_MASTER_DATA, 0x20); + pic.write_port(PIC_MASTER_DATA, 0x04); + pic.write_port(PIC_MASTER_DATA, 0x01); + pic.write_port(PIC_MASTER_DATA, 0x00); // Unmask all. + + pic.raise_irq(3); + pic.raise_irq(1); + + // IRQ 1 is higher priority. + assert_eq!(pic.acknowledge(), Some(0x21)); + pic.write_port(PIC_MASTER_CMD, 0x20); // EOI. + + // Now IRQ 3. + assert_eq!(pic.acknowledge(), Some(0x23)); + pic.write_port(PIC_MASTER_CMD, 0x20); // EOI. + + assert!(!pic.has_pending()); + } + + #[test] + fn test_pic_init_resets_state() { + let mut pic = Pic::new(); + + // Set some state. + pic.master.irr = 0xFF; + pic.master.isr = 0xFF; + pic.master.imr = 0xFF; + + // Re-init should reset IRR, ISR, IMR. 
+ pic.write_port(PIC_MASTER_CMD, 0x11); + assert_eq!(pic.master.irr, 0); + assert_eq!(pic.master.isr, 0); + assert_eq!(pic.master.imr, 0); + } + + #[test] + fn test_pic_io_handler_read() { + let pic = Pic::new(); + // Reading data port returns IMR (0xFF initially). + let val = pic.io_read(PIC_MASTER_DATA, 1); + assert_eq!(val, 0xFF); + } + + #[test] + fn test_pic_masked_irq_not_pending() { + let mut pic = Pic::new(); + + // Init master, mask IRQ 0. + pic.write_port(PIC_MASTER_CMD, 0x11); + pic.write_port(PIC_MASTER_DATA, 0x20); + pic.write_port(PIC_MASTER_DATA, 0x04); + pic.write_port(PIC_MASTER_DATA, 0x01); + pic.write_port(PIC_MASTER_DATA, 0x01); // Mask IRQ 0. + + pic.raise_irq(0); + assert!(!pic.has_pending(), "masked IRQ should not be pending"); + } +} diff --git a/src/vmm/src/windows/devices/pit.rs b/src/vmm/src/windows/devices/pit.rs new file mode 100644 index 000000000..586e385c6 --- /dev/null +++ b/src/vmm/src/windows/devices/pit.rs @@ -0,0 +1,763 @@ +//! 8254 PIT (Programmable Interval Timer) emulation. +//! +//! Emulates the three counters of the 8254/8253 PIT at I/O ports 0x40-0x43: +//! - Counter 0 (port 0x40): System timer, connected to PIC IRQ 0. +//! - Counter 1 (port 0x41): DRAM refresh (not emulated, returns 0). +//! - Counter 2 (port 0x42): PC speaker (not emulated, returns 0). +//! - Control word (port 0x43): Mode/command register. +//! +//! The PIT oscillator runs at 1,193,182 Hz. The kernel programs a reload +//! value and the counter counts down; when it reaches zero, it fires IRQ 0 +//! and reloads. +//! +//! Only counter 0 modes 2 (rate generator) and 3 (square wave) are emulated, +//! as these are the only modes Linux uses for the system timer. + +/// PIT I/O port: Counter 0 data. +pub const PIT_COUNTER0: u16 = 0x40; +/// PIT I/O port: Counter 1 data. +pub const PIT_COUNTER1: u16 = 0x41; +/// PIT I/O port: Counter 2 data. +pub const PIT_COUNTER2: u16 = 0x42; +/// PIT I/O port: Control word register. 
+pub const PIT_COMMAND: u16 = 0x43; + +/// PIT oscillator frequency in Hz. +pub const PIT_FREQUENCY: u64 = 1_193_182; + +/// Nanoseconds per PIT tick (approximately 838.1 ns). +/// Calculated as 1_000_000_000 / 1_193_182 ≈ 838. +/// We use fixed-point math in tick() for accuracy. +const NS_PER_SEC: u64 = 1_000_000_000; + +/// Counter operating mode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CounterMode { + /// Mode 0: Interrupt on terminal count. + InterruptOnTerminal, + /// Mode 2: Rate generator (periodic, fires on reload). + RateGenerator, + /// Mode 3: Square wave generator (periodic). + SquareWave, + /// Other modes (not emulated). + Other(u8), +} + +/// Access mode for reading/writing counter values. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AccessMode { + /// Low byte only. + Low, + /// High byte only. + High, + /// Low byte then high byte. + LoThenHi, +} + +/// Read/write state for two-byte access mode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum RwState { + /// Next byte is the low byte. + Low, + /// Next byte is the high byte. + High, +} + +/// State for a single PIT counter. +#[derive(Debug)] +struct PitCounter { + /// Reload value (what the counter reloads to after reaching zero). + reload: u16, + /// Whether the reload value has been fully written. + reload_ready: bool, + /// Operating mode. + mode: CounterMode, + /// Access mode (lo, hi, or lo-hi byte). + access: AccessMode, + /// Write state for two-byte writes. + write_state: RwState, + /// Read state for two-byte reads. + read_state: RwState, + /// Latched count value (for latch command). + latched_value: Option, + /// Low byte of partial reload write. + write_low: u8, + /// Accumulated nanoseconds of elapsed time (for fractional ticks). 
+ ns_accumulator: u64, +} + +impl PitCounter { + fn new() -> Self { + PitCounter { + reload: 0, + reload_ready: false, + mode: CounterMode::Other(0), + access: AccessMode::LoThenHi, + write_state: RwState::Low, + read_state: RwState::Low, + latched_value: None, + write_low: 0, + ns_accumulator: 0, + } + } + + /// Set the counter mode and access mode from a control word. + fn set_control(&mut self, mode: CounterMode, access: AccessMode) { + self.mode = mode; + self.access = access; + self.write_state = RwState::Low; + self.read_state = RwState::Low; + self.reload_ready = false; + } + + /// Write a data byte to this counter's data port. + fn write_data(&mut self, data: u8) { + match self.access { + AccessMode::Low => { + self.reload = data as u16; + self.reload_ready = true; + self.ns_accumulator = 0; + } + AccessMode::High => { + self.reload = (data as u16) << 8; + self.reload_ready = true; + self.ns_accumulator = 0; + } + AccessMode::LoThenHi => match self.write_state { + RwState::Low => { + self.write_low = data; + self.write_state = RwState::High; + } + RwState::High => { + self.reload = self.write_low as u16 | ((data as u16) << 8); + self.write_state = RwState::Low; + self.reload_ready = true; + self.ns_accumulator = 0; + } + }, + } + } + + /// Read a data byte from this counter's data port. + fn read_data(&mut self) -> u8 { + let value = self.latched_value.unwrap_or_else(|| self.current_count()); + + match self.access { + AccessMode::Low => { + self.latched_value = None; + value as u8 + } + AccessMode::High => { + self.latched_value = None; + (value >> 8) as u8 + } + AccessMode::LoThenHi => match self.read_state { + RwState::Low => { + self.read_state = RwState::High; + value as u8 + } + RwState::High => { + self.read_state = RwState::Low; + self.latched_value = None; + (value >> 8) as u8 + } + }, + } + } + + /// Latch the current count value for reading. 
+ fn latch(&mut self) { + if self.latched_value.is_none() { + self.latched_value = Some(self.current_count()); + } + } + + /// Effective reload value: 0 means 65536 per 8254 specification. + /// + /// In the real 8254 PIT, a reload value of 0 is treated as 65536 + /// (the maximum 16-bit count). This matches BIOS behavior where the + /// PIT is initialized with reload=0 giving ~18.2 Hz. + fn effective_reload(&self) -> u64 { + if self.reload == 0 { + 65536 + } else { + self.reload as u64 + } + } + + /// Compute the current counter value based on accumulated time. + /// + /// A real 8254 counts down from the reload value to 0. Software + /// reads the counter (via latch or direct read) to measure elapsed + /// time. Without this, counter reads return the static reload value + /// and Linux calibration loops that poll the counter never terminate. + fn current_count(&self) -> u16 { + if !self.reload_ready { + // Counter not (yet) programmed, or Mode 0 finished (one-shot). + // Mode 0 after terminal count: counter sits at 0. + if matches!(self.mode, CounterMode::InterruptOnTerminal) { + return 0; + } + return self.reload; + } + let reload = self.effective_reload(); + // How many PIT ticks into the current reload cycle? + let ticks_in_period = + (self.ns_accumulator as u128 * PIT_FREQUENCY as u128) / NS_PER_SEC as u128; + let position = (ticks_in_period as u64) % reload; + // Counter counts down: reload → 0. + (reload - position) as u16 + } + + /// Advance the counter by `elapsed_ns` nanoseconds. + /// + /// Returns the number of times the counter reached zero (fired). + fn tick(&mut self, elapsed_ns: u64) -> u64 { + if !self.reload_ready { + return 0; + } + + let reload = self.effective_reload(); + + match self.mode { + CounterMode::RateGenerator | CounterMode::SquareWave => { + // Accumulate elapsed time. + self.ns_accumulator += elapsed_ns; + + // Calculate how many PIT ticks have elapsed. 
+ // ticks = accumulated_ns * PIT_FREQUENCY / NS_PER_SEC + // To avoid overflow, use u128 for intermediate calculation. + let total_ticks = + (self.ns_accumulator as u128 * PIT_FREQUENCY as u128) / NS_PER_SEC as u128; + + // How many full reload cycles is that? + let fires = total_ticks / reload as u128; + + // Subtract consumed nanoseconds (keep remainder in accumulator). + // consumed_ns = fires * reload * NS_PER_SEC / PIT_FREQUENCY + let consumed_ns = + (fires * reload as u128 * NS_PER_SEC as u128) / PIT_FREQUENCY as u128; + self.ns_accumulator -= consumed_ns as u64; + + fires as u64 + } + CounterMode::InterruptOnTerminal => { + // Mode 0: fires once when count reaches zero. + self.ns_accumulator += elapsed_ns; + let total_ticks = + (self.ns_accumulator as u128 * PIT_FREQUENCY as u128) / NS_PER_SEC as u128; + if total_ticks >= reload as u128 { + self.reload_ready = false; // One-shot: stop after firing. + self.ns_accumulator = 0; + 1 + } else { + 0 + } + } + CounterMode::Other(_) => 0, + } + } +} + +/// 8254 PIT emulation. +pub struct Pit { + counters: [PitCounter; 3], +} + +impl Default for Pit { + fn default() -> Self { + Self::new() + } +} + +impl Pit { + /// Create a new PIT with BIOS-compatible default state. + /// + /// Counter 0 is pre-programmed in Mode 2 (rate generator) with a + /// reload value of 0 (= 65536, giving ~18.2 Hz). This matches real + /// PC behavior where the BIOS initializes the PIT before handing + /// off to the OS. Without this, timer interrupts won't fire until + /// the kernel programs the PIT, but the kernel may depend on timer + /// interrupts *before* it programs the PIT (e.g., jiffies-based + /// timeouts in early hardware probing). 
+ pub fn new() -> Self { + let mut counter0 = PitCounter::new(); + counter0.mode = CounterMode::RateGenerator; + counter0.reload = 0; // 0 = 65536 per 8254 spec → ~18.2 Hz + counter0.reload_ready = true; + Pit { + counters: [counter0, PitCounter::new(), PitCounter::new()], + } + } + + /// Check if the given I/O port belongs to the PIT. + pub fn handles_port(&self, port: u16) -> bool { + (PIT_COUNTER0..=PIT_COMMAND).contains(&port) + } + + /// Write to a PIT I/O port. + pub fn write_port(&mut self, port: u16, data: u8) { + match port { + PIT_COUNTER0 => self.counters[0].write_data(data), + PIT_COUNTER1 => self.counters[1].write_data(data), + PIT_COUNTER2 => self.counters[2].write_data(data), + PIT_COMMAND => self.write_command(data), + _ => {} + } + } + + /// Read from a PIT I/O port. + pub fn read_port(&mut self, port: u16) -> u8 { + match port { + PIT_COUNTER0 => self.counters[0].read_data(), + PIT_COUNTER1 => self.counters[1].read_data(), + PIT_COUNTER2 => self.counters[2].read_data(), + PIT_COMMAND => 0, // Command register is write-only. + _ => 0, + } + } + + /// Advance all counters by `elapsed_ns` nanoseconds. + /// + /// Returns the number of times counter 0 fired (should raise IRQ 0 + /// for each fire). Counters 1 and 2 are also ticked so their + /// `ns_accumulator` stays current — required for `current_count()` + /// to return meaningful values when Linux reads these counters + /// (e.g., PIT counter 2 for TSC calibration). + pub fn tick(&mut self, elapsed_ns: u64) -> u64 { + let fires = self.counters[0].tick(elapsed_ns); + self.counters[1].tick(elapsed_ns); + self.counters[2].tick(elapsed_ns); + fires + } + + /// Parse and apply a control word written to port 0x43. + fn write_command(&mut self, data: u8) { + let counter_idx = ((data >> 6) & 0x03) as usize; + let access_bits = (data >> 4) & 0x03; + let mode_bits = (data >> 1) & 0x07; + + // Counter 3 is invalid (read-back command in 8254, not emulated). 
+ if counter_idx >= 3 { + return; + } + + let access = match access_bits { + 0 => { + // Latch command: latch the current count. + self.counters[counter_idx].latch(); + return; + } + 1 => AccessMode::Low, + 2 => AccessMode::High, + 3 => AccessMode::LoThenHi, + _ => unreachable!(), + }; + + let mode = match mode_bits { + 0 => CounterMode::InterruptOnTerminal, + 2 | 6 => CounterMode::RateGenerator, + 3 | 7 => CounterMode::SquareWave, + m => CounterMode::Other(m), + }; + + self.counters[counter_idx].set_control(mode, access); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ---- PitCounter unit tests ---- + + #[test] + fn test_counter_initial_state() { + let counter = PitCounter::new(); + assert_eq!(counter.reload, 0); + assert!(!counter.reload_ready); + assert_eq!(counter.ns_accumulator, 0); + } + + #[test] + fn test_counter_tick_not_ready() { + let mut counter = PitCounter::new(); + assert_eq!(counter.tick(1_000_000), 0, "should not fire when not ready"); + } + + #[test] + fn test_counter_rate_generator_fires() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::RateGenerator; + counter.reload = 1193; // ~1000 Hz (1 ms period) + counter.reload_ready = true; + + // 1 ms = 1_000_000 ns, should fire ~1 time. + let fires = counter.tick(1_000_000); + assert_eq!(fires, 1); + } + + #[test] + fn test_counter_rate_generator_multiple_fires() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::RateGenerator; + counter.reload = 1193; // ~1000 Hz + counter.reload_ready = true; + + // 10 ms = 10_000_000 ns, should fire ~10 times. + let fires = counter.tick(10_000_000); + assert!( + fires >= 9 && fires <= 11, + "expected ~10 fires, got {}", + fires + ); + } + + #[test] + fn test_counter_accumulates_remainder() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::RateGenerator; + counter.reload = 11932; // ~100 Hz (~10.0005 ms period) + counter.reload_ready = true; + + // Tick 5ms — not enough for one full period. 
+ let fires1 = counter.tick(5_000_000); + assert_eq!(fires1, 0); + + // Tick another 6ms — total 11ms, should fire once. + let fires2 = counter.tick(6_000_000); + assert_eq!(fires2, 1); + } + + #[test] + fn test_counter_square_wave_fires() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::SquareWave; + counter.reload = 1193; + counter.reload_ready = true; + + let fires = counter.tick(1_000_000); + assert_eq!(fires, 1); + } + + #[test] + fn test_counter_mode0_fires_once() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::InterruptOnTerminal; + counter.reload = 1193; + counter.reload_ready = true; + + let fires1 = counter.tick(1_000_000); + assert_eq!(fires1, 1); + + // Mode 0 is one-shot: should not fire again. + let fires2 = counter.tick(1_000_000); + assert_eq!(fires2, 0); + } + + #[test] + fn test_counter_zero_reload_means_65536() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::RateGenerator; + counter.reload = 0; // 0 = 65536 per 8254 spec. + counter.reload_ready = true; + + // 65536 ticks at 1,193,182 Hz → ~54.9ms period. + // 100ms should produce ~1 fire. + let fires = counter.tick(100_000_000); + assert!(fires >= 1 && fires <= 2, "expected ~1 fire, got {}", fires); + } + + #[test] + fn test_counter_write_lo_hi_byte() { + let mut counter = PitCounter::new(); + counter.access = AccessMode::LoThenHi; + counter.mode = CounterMode::RateGenerator; + + // Write low byte first, then high byte. + counter.write_data(0x00); // Low byte. + assert!(!counter.reload_ready); + + counter.write_data(0x10); // High byte → reload = 0x1000. 
+ assert!(counter.reload_ready); + assert_eq!(counter.reload, 0x1000); + } + + #[test] + fn test_counter_write_lo_byte_only() { + let mut counter = PitCounter::new(); + counter.access = AccessMode::Low; + counter.mode = CounterMode::RateGenerator; + + counter.write_data(0x42); + assert!(counter.reload_ready); + assert_eq!(counter.reload, 0x42); + } + + #[test] + fn test_counter_write_hi_byte_only() { + let mut counter = PitCounter::new(); + counter.access = AccessMode::High; + counter.mode = CounterMode::RateGenerator; + + counter.write_data(0x42); + assert!(counter.reload_ready); + assert_eq!(counter.reload, 0x4200); + } + + #[test] + fn test_counter_read_lo_hi_byte() { + let mut counter = PitCounter::new(); + counter.access = AccessMode::LoThenHi; + counter.reload = 0x1234; + + let lo = counter.read_data(); + assert_eq!(lo, 0x34); + + let hi = counter.read_data(); + assert_eq!(hi, 0x12); + } + + #[test] + fn test_counter_latch() { + let mut counter = PitCounter::new(); + counter.access = AccessMode::LoThenHi; + counter.reload = 0xABCD; + + counter.latch(); + assert_eq!(counter.latched_value, Some(0xABCD)); + + // Read should return latched value. + let lo = counter.read_data(); + assert_eq!(lo, 0xCD); + let hi = counter.read_data(); + assert_eq!(hi, 0xAB); + + // Latched value should be consumed. + assert_eq!(counter.latched_value, None); + } + + #[test] + fn test_counter_latch_only_once() { + let mut counter = PitCounter::new(); + counter.reload = 0x1111; + + counter.latch(); + counter.reload = 0x2222; // Change after latch. + counter.latch(); // Should NOT overwrite first latch. 
+ + assert_eq!(counter.latched_value, Some(0x1111)); + } + + // ---- Pit (full device) tests ---- + + #[test] + fn test_pit_handles_port() { + let pit = Pit::new(); + assert!(pit.handles_port(PIT_COUNTER0)); + assert!(pit.handles_port(PIT_COUNTER1)); + assert!(pit.handles_port(PIT_COUNTER2)); + assert!(pit.handles_port(PIT_COMMAND)); + assert!(!pit.handles_port(0x44)); + assert!(!pit.handles_port(0x3F)); + } + + #[test] + fn test_pit_program_counter0_rate_generator() { + let mut pit = Pit::new(); + + // Program counter 0 in rate generator mode, lo-hi access. + // Control word: counter=0 (bits 7-6=00), access=lo-hi (bits 5-4=11), + // mode=2 (bits 3-1=010), BCD=0 (bit 0=0) + // = 0b_00_11_010_0 = 0x34 + pit.write_port(PIT_COMMAND, 0x34); + + // Write reload value: 11932 = 0x2E9C (100 Hz, 10ms period). + pit.write_port(PIT_COUNTER0, 0x9C); // Low byte. + pit.write_port(PIT_COUNTER0, 0x2E); // High byte. + + assert_eq!(pit.counters[0].reload, 0x2E9C); + assert!(pit.counters[0].reload_ready); + + // Tick 11ms — one period is ~10.0005ms, so 11ms is enough for one fire. + let fires = pit.tick(11_000_000); + assert_eq!(fires, 1); + } + + #[test] + fn test_pit_program_counter0_square_wave() { + let mut pit = Pit::new(); + + // Counter 0, lo-hi, mode 3 (square wave), binary. + // = 0b_00_11_011_0 = 0x36 + pit.write_port(PIT_COMMAND, 0x36); + + // Reload 11932 = ~100 Hz (~10.0005ms period). + pit.write_port(PIT_COUNTER0, 0x9C); + pit.write_port(PIT_COUNTER0, 0x2E); + + let fires = pit.tick(11_000_000); + assert_eq!(fires, 1); + } + + #[test] + fn test_pit_latch_command() { + let mut pit = Pit::new(); + + // Program counter 0. + pit.write_port(PIT_COMMAND, 0x34); + pit.write_port(PIT_COUNTER0, 0x00); + pit.write_port(PIT_COUNTER0, 0x10); // reload = 0x1000 + + // Latch counter 0: control word with access=00. + pit.write_port(PIT_COMMAND, 0x00); + + // Read latched value. 
+ let lo = pit.read_port(PIT_COUNTER0); + let hi = pit.read_port(PIT_COUNTER0); + let val = lo as u16 | ((hi as u16) << 8); + assert_eq!(val, 0x1000); + } + + #[test] + fn test_pit_command_register_read_is_zero() { + let mut pit = Pit::new(); + assert_eq!(pit.read_port(PIT_COMMAND), 0); + } + + #[test] + fn test_pit_counter1_counter2_ignored() { + let mut pit = Pit::new(); + + // Programming counters 1 and 2 shouldn't affect tick(). + pit.write_port(PIT_COMMAND, 0x74); // Counter 1, lo-hi, mode 2. + pit.write_port(PIT_COUNTER1, 0x00); + pit.write_port(PIT_COUNTER1, 0x01); + + // Tick should only look at counter 0. + assert_eq!(pit.tick(10_000_000), 0); + } + + #[test] + fn test_pit_fires_with_bios_defaults() { + let mut pit = Pit::new(); + // PIT starts pre-programmed at ~18.2 Hz (reload 0 = 65536). + // 100ms should produce ~1-2 fires. + let fires = pit.tick(100_000_000); + assert!( + fires >= 1 && fires <= 2, + "expected ~1-2 fires from BIOS defaults, got {}", + fires + ); + } + + #[test] + fn test_pit_linux_typical_1000hz() { + let mut pit = Pit::new(); + + // Linux HZ=1000 programs PIT with reload = 1193 (≈1ms period). + pit.write_port(PIT_COMMAND, 0x34); + pit.write_port(PIT_COUNTER0, (1193 & 0xFF) as u8); + pit.write_port(PIT_COUNTER0, (1193 >> 8) as u8); + + // 1 second = 1_000_000_000 ns → should fire ~1000 times. + let fires = pit.tick(1_000_000_000); + assert!( + fires >= 998 && fires <= 1002, + "expected ~1000 fires for HZ=1000, got {}", + fires + ); + } + + #[test] + fn test_pit_linux_typical_100hz() { + let mut pit = Pit::new(); + + // Linux HZ=100 programs PIT with reload = 11932 (≈10ms period). + pit.write_port(PIT_COMMAND, 0x34); + pit.write_port(PIT_COUNTER0, (11932 & 0xFF) as u8); + pit.write_port(PIT_COUNTER0, (11932 >> 8) as u8); + + // 1 second → should fire ~100 times. 
+ let fires = pit.tick(1_000_000_000); + assert!( + fires >= 99 && fires <= 101, + "expected ~100 fires for HZ=100, got {}", + fires + ); + } + + #[test] + fn test_counter_read_decrements_after_tick() { + let mut counter = PitCounter::new(); + counter.mode = CounterMode::RateGenerator; + counter.access = AccessMode::LoThenHi; + counter.reload = 11932; // ~100 Hz + counter.reload_ready = true; + + // Initially, count should equal reload (no time elapsed). + assert_eq!(counter.current_count(), 11932); + + // Tick 5ms — about half a period. Counter should be roughly half. + counter.tick(5_000_000); + let count = counter.current_count(); + assert!( + count < 11932 && count > 0, + "expected count between 0 and 11932, got {}", + count + ); + } + + #[test] + fn test_counter2_counts_down_for_tsc_calibration() { + // Linux's pit_calibrate_tsc() programs counter 2 in Mode 0 + // and reads it in a loop expecting the value to decrease. + let mut pit = Pit::new(); + + // Program counter 2: mode 0 (interrupt on terminal), lo-hi. + // Control word: counter=2 (bits 7-6=10), access=lo-hi (bits 5-4=11), + // mode=0 (bits 3-1=000), BCD=0 (bit 0=0) + // = 0b_10_11_000_0 = 0xB0 + pit.write_port(PIT_COMMAND, 0xB0); + pit.write_port(PIT_COUNTER2, 0xFF); // Low byte. + pit.write_port(PIT_COUNTER2, 0xFF); // High byte → reload = 0xFFFF. + + // Tick the PIT (simulating vCPU loop iterations). + pit.tick(10_000_000); // 10ms + + // Latch counter 2 and read it. + pit.write_port(PIT_COMMAND, 0x80); // Latch counter 2 (counter=2, access=00). + let lo = pit.read_port(PIT_COUNTER2); + let hi = pit.read_port(PIT_COUNTER2); + let count = lo as u16 | ((hi as u16) << 8); + + // Count should be less than the initial reload value. + assert!( + count < 0xFFFF, + "counter 2 should have decremented, got {:#X}", + count + ); + } + + #[test] + fn test_pit_incremental_ticks() { + let mut pit = Pit::new(); + + // HZ=100: reload = 11932. 
+ pit.write_port(PIT_COMMAND, 0x34); + pit.write_port(PIT_COUNTER0, (11932 & 0xFF) as u8); + pit.write_port(PIT_COUNTER0, (11932 >> 8) as u8); + + // Tick in small increments (1ms each) for 100ms total. + let mut total_fires = 0u64; + for _ in 0..100 { + total_fires += pit.tick(1_000_000); + } + // 100ms at HZ=100 → should fire ~10 times. + assert!( + total_fires >= 9 && total_fires <= 11, + "expected ~10 fires over 100ms, got {}", + total_fires + ); + } +} diff --git a/src/vmm/src/windows/devices/serial.rs b/src/vmm/src/windows/devices/serial.rs new file mode 100644 index 000000000..87baee440 --- /dev/null +++ b/src/vmm/src/windows/devices/serial.rs @@ -0,0 +1,549 @@ +//! 16550 UART serial console emulation. +//! +//! Emulates a basic 16550 UART at I/O ports 0x3F8-0x3FF (COM1). +//! Provides serial console output from the guest kernel/userspace. +//! +//! Register layout (base = 0x3F8): +//! +0 (THR/RBR): Transmit/Receive buffer +//! +1 (IER): Interrupt Enable Register +//! +2 (IIR/FCR): Interrupt Identification / FIFO Control +//! +3 (LCR): Line Control Register +//! +4 (MCR): Modem Control Register +//! +5 (LSR): Line Status Register +//! +6 (MSR): Modem Status Register +//! +7 (SCR): Scratch Register +//! +//! When DLAB (bit 7 of LCR) is set: +//! +0 (DLL): Divisor Latch Low +//! +1 (DLH): Divisor Latch High + +use std::io::Write; +use std::sync::Mutex; + +use super::super::vcpu::IoHandler; + +/// COM1 base I/O port address. +pub const COM1_BASE: u16 = 0x3F8; + +/// COM1 I/O port range (8 registers). +pub const COM1_SIZE: u16 = 8; + +/// Line Status Register bit flags. +const LSR_DATA_READY: u8 = 0x01; +const LSR_THR_EMPTY: u8 = 0x20; +const LSR_IDLE: u8 = 0x40; + +/// Interrupt Identification Register values. +const IIR_NO_INTERRUPT: u8 = 0x01; +const IIR_THRE: u8 = 0x02; // Transmitter Holding Register Empty +const IIR_FIFO_ENABLED: u8 = 0xC0; + +/// IER bit: Transmitter Holding Register Empty interrupt. 
+const IER_THRE: u8 = 0x02; + +/// FCR bit: FIFO Enable. +const FCR_FIFO_ENABLE: u8 = 0x01; +/// FCR bit: Transmit FIFO Reset. +const FCR_TX_RESET: u8 = 0x04; + +/// 16550 FIFO depth (bytes). +const FIFO_SIZE: usize = 16; + +/// Serial port state. +struct SerialState { + /// Interrupt Enable Register. + ier: u8, + /// Line Control Register. + lcr: u8, + /// Modem Control Register. + mcr: u8, + /// Line Status Register. + lsr: u8, + /// Modem Status Register. + msr: u8, + /// Scratch register. + scr: u8, + /// Divisor Latch Low byte. + dll: u8, + /// Divisor Latch High byte. + dlh: u8, + /// Output sink. + output: Box, + /// THRE interrupt pending (set after THR write when IER THRE bit is set). + thre_pending: bool, + /// Whether FIFO mode is enabled (FCR bit 0). + fifo_enabled: bool, + /// Transmit FIFO buffer. When FIFO is enabled, bytes are buffered here + /// and flushed to `output` when the buffer is full, a newline is written, + /// or the guest reads IIR (polling for completion). + tx_fifo: Vec, +} + +/// 16550 UART emulation. +pub struct Serial { + base_port: u16, + state: Mutex, +} + +impl Serial { + /// Create a new serial port emulation at the given base I/O port. + pub fn new(base_port: u16, output: Box) -> Self { + Serial { + base_port, + state: Mutex::new(SerialState { + ier: 0, + lcr: 0, + mcr: 0, + lsr: LSR_THR_EMPTY | LSR_IDLE, // Transmitter is ready + msr: 0, + scr: 0, + dll: 0, + dlh: 0, + output, + thre_pending: false, + fifo_enabled: false, + tx_fifo: Vec::with_capacity(FIFO_SIZE), + }), + } + } + + /// Create a serial port that writes to stdout. + pub fn stdout(base_port: u16) -> Self { + Self::new(base_port, Box::new(std::io::stdout())) + } + + /// Check if the given I/O port is within this serial port's range. + pub fn handles_port(&self, port: u16) -> bool { + port >= self.base_port && port < self.base_port + COM1_SIZE + } + + /// Check if the serial device has a pending interrupt. 
+ pub fn has_interrupt(&self) -> bool { + self.state.lock().unwrap().thre_pending + } + + /// Handle an I/O port read. + pub fn read(&self, port: u16) -> u8 { + let offset = port - self.base_port; + let mut state = self.state.lock().unwrap(); + let dlab = (state.lcr & 0x80) != 0; + + match offset { + 0 => { + if dlab { + state.dll + } else { + // RBR — receive buffer (no input support yet, return 0) + state.lsr &= !LSR_DATA_READY; + 0 + } + } + 1 => { + if dlab { + state.dlh + } else { + state.ier + } + } + 2 => { + // IIR — check for pending interrupt. + // Flush any buffered FIFO data (guest is polling for completion). + if state.fifo_enabled && !state.tx_fifo.is_empty() { + let pending: Vec = state.tx_fifo.drain(..).collect(); + let _ = state.output.write_all(&pending); + let _ = state.output.flush(); + } + if state.thre_pending { + state.thre_pending = false; + IIR_THRE | IIR_FIFO_ENABLED + } else { + IIR_NO_INTERRUPT | IIR_FIFO_ENABLED + } + } + 3 => state.lcr, + 4 => state.mcr, + 5 => { + let lsr = state.lsr; + // Reading LSR clears some bits + state.lsr &= !(LSR_DATA_READY); + lsr + } + 6 => state.msr, + 7 => state.scr, + _ => 0, + } + } + + /// Handle an I/O port write. + pub fn write(&self, port: u16, data: u8) { + let offset = port - self.base_port; + let mut state = self.state.lock().unwrap(); + let dlab = (state.lcr & 0x80) != 0; + + match offset { + 0 => { + if dlab { + state.dll = data; + } else if state.fifo_enabled { + // THR with FIFO: buffer bytes, flush on newline or full. + state.tx_fifo.push(data); + if data == b'\n' || state.tx_fifo.len() >= FIFO_SIZE { + let pending: Vec = state.tx_fifo.drain(..).collect(); + let _ = state.output.write_all(&pending); + let _ = state.output.flush(); + } + state.lsr |= LSR_THR_EMPTY | LSR_IDLE; + if state.ier & IER_THRE != 0 { + state.thre_pending = true; + } + } else { + // THR without FIFO: immediate output per byte. 
+ let _ = state.output.write_all(&[data]); + let _ = state.output.flush(); + state.lsr |= LSR_THR_EMPTY | LSR_IDLE; + if state.ier & IER_THRE != 0 { + state.thre_pending = true; + } + } + } + 1 => { + if dlab { + state.dlh = data; + } else { + let old_ier = state.ier; + state.ier = data & 0x0F; // Only lower 4 bits valid + // Enabling THRE interrupt when THR is already empty triggers it + if (state.ier & IER_THRE != 0) + && (old_ier & IER_THRE == 0) + && (state.lsr & LSR_THR_EMPTY != 0) + { + state.thre_pending = true; + } + } + } + 2 => { + // FCR — FIFO Control Register. + state.fifo_enabled = data & FCR_FIFO_ENABLE != 0; + if data & FCR_TX_RESET != 0 { + // TX FIFO reset: flush pending data and clear buffer. + if !state.tx_fifo.is_empty() { + let pending: Vec = state.tx_fifo.drain(..).collect(); + let _ = state.output.write_all(&pending); + let _ = state.output.flush(); + } + } + if !state.fifo_enabled && !state.tx_fifo.is_empty() { + // Disabling FIFO: flush remaining data. + let pending: Vec = state.tx_fifo.drain(..).collect(); + let _ = state.output.write_all(&pending); + let _ = state.output.flush(); + } + } + 3 => state.lcr = data, + 4 => state.mcr = data & 0x1F, // Only lower 5 bits valid + 5 => {} // LSR is read-only + 6 => {} // MSR is read-only + 7 => state.scr = data, + _ => {} + } + } +} + +impl IoHandler for Serial { + fn io_read(&self, port: u16, _size: u8) -> u32 { + self.read(port) as u32 + } + + fn io_write(&self, port: u16, _size: u8, data: u32) { + self.write(port, data as u8); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{Arc, Mutex as StdMutex}; + + /// A test output sink that captures written bytes. 
+ struct CaptureOutput { + buffer: Arc>>, + } + + impl CaptureOutput { + fn new() -> (Self, Arc>>) { + let buffer = Arc::new(StdMutex::new(Vec::new())); + ( + CaptureOutput { + buffer: buffer.clone(), + }, + buffer, + ) + } + } + + impl Write for CaptureOutput { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.buffer.lock().unwrap().extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + } + + fn create_test_serial() -> (Serial, Arc>>) { + let (output, buffer) = CaptureOutput::new(); + let serial = Serial::new(COM1_BASE, Box::new(output)); + (serial, buffer) + } + + #[test] + fn test_serial_handles_port() { + let (serial, _) = create_test_serial(); + assert!(serial.handles_port(COM1_BASE)); + assert!(serial.handles_port(COM1_BASE + 7)); + assert!(!serial.handles_port(COM1_BASE - 1)); + assert!(!serial.handles_port(COM1_BASE + 8)); + } + + #[test] + fn test_serial_lsr_initially_ready() { + let (serial, _) = create_test_serial(); + let lsr = serial.read(COM1_BASE + 5); + assert_ne!(lsr & LSR_THR_EMPTY, 0, "THR should be empty initially"); + assert_ne!(lsr & LSR_IDLE, 0, "transmitter should be idle initially"); + } + + #[test] + fn test_serial_write_character() { + let (serial, buffer) = create_test_serial(); + + serial.write(COM1_BASE, b'H'); + serial.write(COM1_BASE, b'i'); + + let captured = buffer.lock().unwrap(); + assert_eq!(&*captured, b"Hi"); + } + + #[test] + fn test_serial_write_string() { + let (serial, buffer) = create_test_serial(); + + for &byte in b"Hello, VM!\n" { + serial.write(COM1_BASE, byte); + } + + let captured = buffer.lock().unwrap(); + assert_eq!(std::str::from_utf8(&captured).unwrap(), "Hello, VM!\n"); + } + + #[test] + fn test_serial_scratch_register() { + let (serial, _) = create_test_serial(); + + serial.write(COM1_BASE + 7, 0x42); + assert_eq!(serial.read(COM1_BASE + 7), 0x42); + + serial.write(COM1_BASE + 7, 0xFF); + assert_eq!(serial.read(COM1_BASE + 7), 0xFF); + } + + 
#[test] + fn test_serial_dlab_divisor_latch() { + let (serial, _) = create_test_serial(); + + // Set DLAB bit in LCR + serial.write(COM1_BASE + 3, 0x80); + + // Write divisor + serial.write(COM1_BASE, 0x01); // DLL + serial.write(COM1_BASE + 1, 0x00); // DLH + + // Read divisor back + assert_eq!(serial.read(COM1_BASE), 0x01); // DLL + assert_eq!(serial.read(COM1_BASE + 1), 0x00); // DLH + + // Clear DLAB + serial.write(COM1_BASE + 3, 0x03); // 8N1 + + // Now register 0 is THR/RBR again, not DLL + // Writing should output a character, not change the divisor + let (serial2, buffer2) = create_test_serial(); + serial2.write(COM1_BASE + 3, 0x03); // 8N1, DLAB=0 + serial2.write(COM1_BASE, b'X'); + let captured = buffer2.lock().unwrap(); + assert_eq!(&*captured, b"X"); + } + + #[test] + fn test_serial_ier_mask() { + let (serial, _) = create_test_serial(); + + // IER only uses lower 4 bits + serial.write(COM1_BASE + 1, 0xFF); + assert_eq!(serial.read(COM1_BASE + 1), 0x0F); + } + + #[test] + fn test_serial_mcr_mask() { + let (serial, _) = create_test_serial(); + + // MCR only uses lower 5 bits + serial.write(COM1_BASE + 4, 0xFF); + assert_eq!(serial.read(COM1_BASE + 4), 0x1F); + } + + #[test] + fn test_serial_iir_no_interrupt() { + let (serial, _) = create_test_serial(); + + let iir = serial.read(COM1_BASE + 2); + assert_ne!(iir & IIR_NO_INTERRUPT, 0, "no interrupt should be pending"); + } + + #[test] + fn test_serial_io_handler_trait() { + let (serial, buffer) = create_test_serial(); + + // Use through IoHandler trait + serial.io_write(COM1_BASE, 1, b'A' as u32); + serial.io_write(COM1_BASE, 1, b'B' as u32); + + let lsr = serial.io_read(COM1_BASE + 5, 1); + assert_ne!(lsr & LSR_THR_EMPTY as u32, 0); + + let captured = buffer.lock().unwrap(); + assert_eq!(&*captured, b"AB"); + } + + #[test] + fn test_serial_thr_stays_ready_after_write() { + let (serial, _) = create_test_serial(); + + serial.write(COM1_BASE, b'X'); + let lsr = serial.read(COM1_BASE + 5); + assert_ne!(lsr & 
LSR_THR_EMPTY, 0, "THR should be ready after write"); + } + + // ---- FIFO tests ---- + + #[test] + fn test_fifo_enable_via_fcr() { + let (serial, _) = create_test_serial(); + // FIFO should be disabled initially. + assert!(!serial.state.lock().unwrap().fifo_enabled); + // Write FCR with FIFO enable bit. + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + assert!(serial.state.lock().unwrap().fifo_enabled); + // Disable FIFO. + serial.write(COM1_BASE + 2, 0); + assert!(!serial.state.lock().unwrap().fifo_enabled); + } + + #[test] + fn test_fifo_batches_output() { + let (serial, buffer) = create_test_serial(); + // Enable FIFO. + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + + // Write bytes that don't trigger flush (no newline, under FIFO_SIZE). + for &b in b"Hello" { + serial.write(COM1_BASE, b); + } + // Buffer should be empty (data is in FIFO, not flushed yet). + assert!( + buffer.lock().unwrap().is_empty(), + "FIFO should batch writes" + ); + + // Write newline to trigger flush. + serial.write(COM1_BASE, b'\n'); + let captured = buffer.lock().unwrap().clone(); + assert_eq!(captured, b"Hello\n", "newline should flush FIFO"); + } + + #[test] + fn test_fifo_flushes_on_full() { + let (serial, buffer) = create_test_serial(); + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + + // Write exactly FIFO_SIZE bytes (no newline). + for i in 0..FIFO_SIZE { + serial.write(COM1_BASE, b'A' + (i as u8 % 26)); + } + // Should have flushed on the 16th byte. + let captured = buffer.lock().unwrap().clone(); + assert_eq!(captured.len(), FIFO_SIZE, "FIFO should flush when full"); + } + + #[test] + fn test_fifo_flushes_on_iir_read() { + let (serial, buffer) = create_test_serial(); + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + + // Write partial data (no newline, under FIFO_SIZE). + for &b in b"Test" { + serial.write(COM1_BASE, b); + } + assert!(buffer.lock().unwrap().is_empty(), "not flushed yet"); + + // Read IIR — should flush the FIFO. 
+ let _iir = serial.read(COM1_BASE + 2); + let captured = buffer.lock().unwrap().clone(); + assert_eq!(captured, b"Test", "IIR read should flush FIFO"); + } + + #[test] + fn test_fifo_disable_flushes_remaining() { + let (serial, buffer) = create_test_serial(); + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + + for &b in b"Data" { + serial.write(COM1_BASE, b); + } + assert!(buffer.lock().unwrap().is_empty()); + + // Disable FIFO — should flush remaining data. + serial.write(COM1_BASE + 2, 0); + let captured = buffer.lock().unwrap().clone(); + assert_eq!(captured, b"Data", "disabling FIFO should flush"); + } + + #[test] + fn test_fifo_tx_reset_flushes() { + let (serial, buffer) = create_test_serial(); + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + + for &b in b"Reset" { + serial.write(COM1_BASE, b); + } + assert!(buffer.lock().unwrap().is_empty()); + + // TX FIFO reset. + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE | FCR_TX_RESET); + let captured = buffer.lock().unwrap().clone(); + assert_eq!(captured, b"Reset", "TX reset should flush FIFO"); + } + + #[test] + fn test_no_fifo_immediate_output() { + let (serial, buffer) = create_test_serial(); + // FIFO disabled (default) — each byte goes out immediately. + serial.write(COM1_BASE, b'A'); + assert_eq!(buffer.lock().unwrap().as_slice(), b"A"); + serial.write(COM1_BASE, b'B'); + assert_eq!(buffer.lock().unwrap().as_slice(), b"AB"); + } + + #[test] + fn test_fifo_lsr_stays_ready() { + let (serial, _) = create_test_serial(); + serial.write(COM1_BASE + 2, FCR_FIFO_ENABLE); + + // Even with FIFO buffering, LSR should report THR empty. 
+ serial.write(COM1_BASE, b'X'); + let lsr = serial.read(COM1_BASE + 5); + assert_ne!(lsr & LSR_THR_EMPTY, 0, "THR should be ready in FIFO mode"); + } +} diff --git a/src/vmm/src/windows/devices/virtio/balloon.rs b/src/vmm/src/windows/devices/virtio/balloon.rs new file mode 100644 index 000000000..ec3aa1250 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/balloon.rs @@ -0,0 +1,240 @@ +//! Virtio-balloon device (virtio spec v1.2 Section 5.5). +//! +//! Allows the host to request the guest to return or reclaim memory pages. +//! The guest driver inflates the balloon (returns pages) or deflates it +//! (reclaims pages) by sending page frame numbers on the respective queues. +//! +//! This implementation is protocol-only: inflate/deflate queues are processed +//! but no actual memory discard happens on the host side. Actual memory +//! reclamation would require extending `GuestMemoryAccessor` with a `discard()` +//! method (deferred to a future iteration). + +use super::mmio::VirtioDeviceBackend; +use super::queue::{GuestMemoryAccessor, Virtqueue}; + +/// Virtio device ID for balloon (spec 5.5). +const VIRTIO_ID_BALLOON: u32 = 5; + +/// VIRTIO_F_VERSION_1 — bit 32 (feature page 1, bit 0). +const VIRTIO_F_VERSION_1_PAGE1: u32 = 1; + +/// Maximum queue size for inflate/deflate queues. +const QUEUE_MAX_SIZE: u16 = 256; + +/// Inflate queue index (guest returns pages to host). +const INFLATE_QUEUE: u32 = 0; + +/// Deflate queue index (guest reclaims pages from host). +const DEFLATE_QUEUE: u32 = 1; + +/// Virtio-balloon backend. +/// +/// Config space layout (little-endian): +/// - offset 0: `num_pages` (u32) — target number of pages the balloon should hold. +/// - offset 4: `actual` (u32) — actual number of pages the balloon currently holds. +/// +/// The host sets `num_pages` to request inflation/deflation. +/// The guest writes `actual` to report current balloon size. +pub struct VirtioBalloon { + /// Target number of pages (set by host, read by guest). 
+ num_pages: u32, + /// Actual number of pages (set by guest via config write). + actual: u32, +} + +impl VirtioBalloon { + pub fn new() -> Self { + VirtioBalloon { + num_pages: 0, + actual: 0, + } + } + + /// Set the target number of balloon pages (host API). + pub fn set_target_pages(&mut self, pages: u32) { + self.num_pages = pages; + } + + /// Write to config space at the given byte offset. + /// + /// The guest writes `actual` to offset 4 to report the current balloon size. + pub fn write_config(&mut self, offset: u64, value: u32) { + if offset == 4 { + self.actual = value; + } + // Writes to other offsets are silently ignored. + } +} + +impl VirtioDeviceBackend for VirtioBalloon { + fn device_id(&self) -> u32 { + VIRTIO_ID_BALLOON + } + + fn device_features(&self, page: u32) -> u32 { + match page { + 0 => 0, + 1 => VIRTIO_F_VERSION_1_PAGE1, + _ => 0, + } + } + + fn read_config(&self, offset: u64) -> u32 { + match offset { + 0 => self.num_pages, + 4 => self.actual, + _ => 0, + } + } + + fn write_config(&mut self, offset: u64, value: u32) { + // Delegate to the inherent method. + VirtioBalloon::write_config(self, offset, value); + } + + fn num_queues(&self) -> usize { + 2 // inflate + deflate + } + + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + QUEUE_MAX_SIZE + } + + fn queue_notify( + &mut self, + queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + let mut raised = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(e) => { + log::warn!("virtio-balloon: failed to read descriptor chain: {}", e); + break; + } + }; + + // Count page frame numbers (PFNs) in the chain. + // Each PFN is a u32 (4 bytes). The guest sends arrays of PFNs. + let mut pfn_count = 0u32; + for desc in &chain { + if desc.is_write() { + continue; // PFN buffers are device-readable. 
+ } + pfn_count += desc.len / 4; + } + + match queue_idx { + INFLATE_QUEUE => { + // Guest is returning pages. In a full implementation, we would + // call madvise(MADV_DONTNEED) or equivalent on the host pages. + // For now, just track the count. + self.actual = self.actual.saturating_add(pfn_count); + log::trace!( + "virtio-balloon: inflate {} pages, actual={}", + pfn_count, + self.actual + ); + } + DEFLATE_QUEUE => { + // Guest is reclaiming pages. + self.actual = self.actual.saturating_sub(pfn_count); + log::trace!( + "virtio-balloon: deflate {} pages, actual={}", + pfn_count, + self.actual + ); + } + _ => {} + } + + if let Err(e) = queue.add_used(head, 0, mem) { + log::warn!("virtio-balloon: failed to add used buffer: {}", e); + break; + } + raised = true; + } + + raised + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_device_id() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.device_id(), 5); + } + + #[test] + fn test_num_queues() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.num_queues(), 2); + } + + #[test] + fn test_features_page0() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.device_features(0), 0); + } + + #[test] + fn test_features_page1_version_1() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.device_features(1), 1); // VIRTIO_F_VERSION_1 + } + + #[test] + fn test_features_page2_zero() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.device_features(2), 0); + } + + #[test] + fn test_config_defaults() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.read_config(0), 0); // num_pages + assert_eq!(balloon.read_config(4), 0); // actual + } + + #[test] + fn test_set_target_pages() { + let mut balloon = VirtioBalloon::new(); + balloon.set_target_pages(100); + assert_eq!(balloon.read_config(0), 100); + } + + #[test] + fn test_write_config_actual() { + let mut balloon = VirtioBalloon::new(); + balloon.write_config(4, 50); + 
assert_eq!(balloon.read_config(4), 50); + } + + #[test] + fn test_write_config_ignores_other_offsets() { + let mut balloon = VirtioBalloon::new(); + balloon.write_config(0, 999); // Should not change num_pages. + assert_eq!(balloon.read_config(0), 0); + } + + #[test] + fn test_read_config_unknown_offset() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.read_config(8), 0); + assert_eq!(balloon.read_config(12), 0); + } + + #[test] + fn test_queue_max_size() { + let balloon = VirtioBalloon::new(); + assert_eq!(balloon.queue_max_size(0), 256); + assert_eq!(balloon.queue_max_size(1), 256); + } +} diff --git a/src/vmm/src/windows/devices/virtio/block.rs b/src/vmm/src/windows/devices/virtio/block.rs new file mode 100644 index 000000000..31b5817ee --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/block.rs @@ -0,0 +1,994 @@ +//! Virtio-blk device backend (virtio spec v1.2 Section 5.2). +//! +//! Provides a file-backed block device that processes read, write, +//! and flush requests through the virtqueue. +//! +//! When a worker thread is started via `start_worker()`, disk I/O is +//! dispatched asynchronously to avoid blocking the vCPU loop. Without +//! a worker, requests are processed synchronously (fallback mode). + +use std::sync::mpsc; +use std::thread; + +use super::block_worker::{BlockCompletion, BlockRequest, BlockWorker, BufferDesc, RequestType}; +use super::disk::DiskBackend; +use super::mmio::VirtioDeviceBackend; +use super::queue::{Descriptor, GuestMemoryAccessor, Virtqueue}; + +/// Virtio device ID for block devices. +const VIRTIO_BLK_ID: u32 = 2; + +/// Block size in bytes (standard sector size). +const SECTOR_SIZE: u64 = 512; + +// Virtio-blk feature bits. +/// Device has a maximum size (not used for now). +#[allow(dead_code)] +const VIRTIO_BLK_F_SIZE_MAX: u32 = 1; +/// Device has a maximum segment size (not used for now). +#[allow(dead_code)] +const VIRTIO_BLK_F_SEG_MAX: u32 = 2; +/// Read-only device. 
+const VIRTIO_BLK_F_RO: u32 = 5; +/// VIRTIO_F_VERSION_1 — required for virtio 1.0+. +const VIRTIO_F_VERSION_1: u32 = 0; // Bit 32, goes in features page 1. + +// Virtio-blk request types. +const VIRTIO_BLK_T_IN: u32 = 0; // Read from disk. +const VIRTIO_BLK_T_OUT: u32 = 1; // Write to disk. +const VIRTIO_BLK_T_FLUSH: u32 = 4; // Flush. + +// Virtio-blk status values. +const VIRTIO_BLK_S_OK: u8 = 0; +const VIRTIO_BLK_S_IOERR: u8 = 1; +const VIRTIO_BLK_S_UNSUPP: u8 = 2; + +/// Virtio-blk device backed by a `DiskBackend`. +/// +/// Supports two modes: +/// - **Sync** (default): disk I/O in the vCPU thread (simple, but blocks) +/// - **Async** (after `start_worker()`): disk I/O in a dedicated thread +pub struct VirtioBlock { + /// Disk backend — owned here in sync mode, moved to worker in async mode. + disk: Option>, + capacity: u64, // In sectors. + read_only: bool, + + // --- Async worker fields (None until start_worker is called) --- + /// Channel to send requests to the worker thread. + request_tx: Option>, + /// Channel to receive completions from the worker thread. + completion_rx: Option>, + /// Worker thread join handle. + worker_handle: Option>, +} + +impl VirtioBlock { + /// Create a new virtio-blk device from a disk backend. + /// + /// `read_only` marks the device as read-only (rejects write requests). + pub fn new(disk: Box, read_only: bool) -> Self { + let capacity = disk.capacity_bytes() / SECTOR_SIZE; + VirtioBlock { + disk: Some(disk), + capacity, + read_only, + request_tx: None, + completion_rx: None, + worker_handle: None, + } + } + + /// Get disk capacity in sectors. + pub fn capacity(&self) -> u64 { + self.capacity + } + + /// Whether the async worker is active. + pub fn has_worker(&self) -> bool { + self.request_tx.is_some() + } + + /// Start the async block I/O worker thread (Plan B: WHPX-safe). + /// + /// Moves the disk backend to the worker. After this call, `queue_notify` + /// dispatches requests asynchronously instead of blocking. 
+ /// + /// The worker thread never accesses guest memory — all guest memory + /// reads/writes happen on the vCPU thread (in queue_notify and + /// drain_completions). + pub fn start_worker(&mut self, name: &str) { + let disk = match self.disk.take() { + Some(d) => d, + None => { + log::warn!("start_worker called but disk already moved to worker"); + return; + } + }; + + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let worker = BlockWorker::new(req_rx, comp_tx, disk, self.read_only); + let handle = worker.run(name); + + self.request_tx = Some(req_tx); + self.completion_rx = Some(comp_rx); + self.worker_handle = Some(handle); + + log::info!("block worker '{}' started (Plan B)", name); + } + + /// Stop the worker thread and reclaim resources. + /// + /// Drops the request channel (worker exits on recv error), then joins. + pub fn stop_worker(&mut self) { + // Drop the sender to signal the worker to exit. + self.request_tx.take(); + self.completion_rx.take(); + + if let Some(handle) = self.worker_handle.take() { + let _ = handle.join(); + log::info!("block worker stopped"); + } + } + + /// Drain pending completions from the worker and update the used ring. + /// + /// Called from `tick_and_poll()` in the vCPU loop. Returns `true` if + /// any completions were processed (interrupt should be raised). + /// + /// **Plan B**: This method writes read data and status bytes to guest + /// memory on the vCPU thread, which is safe for WHPX. + pub fn drain_completions( + &mut self, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + let rx = match self.completion_rx { + Some(ref rx) => rx, + None => return false, + }; + + let mut drained = false; + while let Ok(comp) = rx.try_recv() { + // Write read data to guest memory (scatter to original buffer locations). 
+ if let Some(ref read_data) = comp.read_data { + let mut data_offset: usize = 0; + for target in &comp.read_targets { + let end = data_offset + target.len as usize; + if end <= read_data.len() { + let _ = mem.write_at(target.addr, &read_data[data_offset..end]); + } + data_offset = end; + } + } + + // Write status byte to guest memory. + let _ = mem.write_at(comp.status_addr, &[comp.status]); + + let _ = queue.add_used(comp.head_index, comp.bytes_written, mem); + drained = true; + } + drained + } + + /// Parse a descriptor chain header and build a BlockRequest. + /// + /// For write requests (Plan B), pre-reads data from guest memory + /// into the request's `write_data` field so the worker thread + /// never needs to access guest memory. + /// + /// Returns None if the chain is malformed. + fn parse_request( + chain: &[Descriptor], + head_index: u16, + mem: &dyn GuestMemoryAccessor, + ) -> Option { + if chain.len() < 2 { + log::debug!("BLK: short chain len={}", chain.len()); + return None; + } + + let header_desc = &chain[0]; + if header_desc.len < 16 { + log::debug!("BLK: short header len={}", header_desc.len); + return None; + } + + let mut header_buf = [0u8; 16]; + if mem.read_at(header_desc.addr, &mut header_buf).is_err() { + log::debug!("BLK: header read failed addr=0x{:X}", header_desc.addr); + return None; + } + + let raw_type = + u32::from_le_bytes([header_buf[0], header_buf[1], header_buf[2], header_buf[3]]); + let sector = u64::from_le_bytes([ + header_buf[8], + header_buf[9], + header_buf[10], + header_buf[11], + header_buf[12], + header_buf[13], + header_buf[14], + header_buf[15], + ]); + + let req_type = match raw_type { + VIRTIO_BLK_T_IN => RequestType::Read, + VIRTIO_BLK_T_OUT => RequestType::Write, + VIRTIO_BLK_T_FLUSH => RequestType::Flush, + _ => RequestType::Unsupported, + }; + + // Middle descriptors: data buffers. Last descriptor: status byte. 
+ let data_descs = &chain[1..chain.len() - 1]; + let status_desc = chain.last().unwrap(); + + let data_buffers: Vec = data_descs + .iter() + .map(|d| BufferDesc { + addr: d.addr, + len: d.len, + is_write: d.is_write(), + }) + .collect(); + + // Plan B: For write requests, pre-read data from guest memory + // so the worker thread never needs guest memory access. + let write_data = if req_type == RequestType::Write { + let mut all_data = Vec::new(); + for desc in data_descs { + if !desc.is_write() { + // Device-readable buffer: contains data to write to disk. + let mut buf = vec![0u8; desc.len as usize]; + if mem.read_at(desc.addr, &mut buf).is_err() { + log::debug!("BLK: pre-read write data failed addr=0x{:X}", desc.addr); + return None; + } + all_data.extend_from_slice(&buf); + } + } + Some(all_data) + } else { + None + }; + + Some(BlockRequest { + head_index, + req_type, + sector, + data_buffers, + status_addr: status_desc.addr, + write_data, + }) + } + + // --- Synchronous fallback (used when no worker is active) --- + + /// Process a single virtio-blk request from a descriptor chain. + fn process_request(&mut self, chain: &[Descriptor], mem: &dyn GuestMemoryAccessor) -> u8 { + // Minimum: header + status (flush has no data descriptor). + if chain.len() < 2 { + log::debug!("BLK: short chain len={}", chain.len()); + return VIRTIO_BLK_S_IOERR; + } + + // First descriptor: request header (device-readable). 
+ let header_desc = &chain[0]; + if header_desc.len < 16 { + log::debug!("BLK: short header len={}", header_desc.len); + return VIRTIO_BLK_S_IOERR; + } + + let mut header_buf = [0u8; 16]; + if mem.read_at(header_desc.addr, &mut header_buf).is_err() { + log::debug!("BLK: header read failed addr=0x{:X}", header_desc.addr); + return VIRTIO_BLK_S_IOERR; + } + + let req_type = + u32::from_le_bytes([header_buf[0], header_buf[1], header_buf[2], header_buf[3]]); + let sector = u64::from_le_bytes([ + header_buf[8], + header_buf[9], + header_buf[10], + header_buf[11], + header_buf[12], + header_buf[13], + header_buf[14], + header_buf[15], + ]); + + // Middle descriptors: data buffer(s) (may be empty for flush). + // Last descriptor: status byte (device-writable). + let data_descs = &chain[1..chain.len() - 1]; + + match req_type { + VIRTIO_BLK_T_IN => { + if data_descs.is_empty() { + return VIRTIO_BLK_S_IOERR; + } + self.handle_read(sector, data_descs, mem) + } + VIRTIO_BLK_T_OUT => { + if data_descs.is_empty() { + return VIRTIO_BLK_S_IOERR; + } + self.handle_write(sector, data_descs, mem) + } + VIRTIO_BLK_T_FLUSH => self.handle_flush(), + _ => VIRTIO_BLK_S_UNSUPP, + } + } + + fn handle_read( + &mut self, + sector: u64, + data_descs: &[Descriptor], + mem: &dyn GuestMemoryAccessor, + ) -> u8 { + let disk = match self.disk { + Some(ref mut d) => d, + None => return VIRTIO_BLK_S_IOERR, + }; + let mut offset = sector * SECTOR_SIZE; + + for (i, desc) in data_descs.iter().enumerate() { + if !desc.is_write() { + log::debug!( + "BLK READ: desc[{}] not writable, flags=0x{:X}", + i, + desc.flags + ); + return VIRTIO_BLK_S_IOERR; + } + let mut buf = vec![0u8; desc.len as usize]; + if let Err(e) = disk.read_at(offset, &mut buf) { + log::debug!( + "BLK READ: disk.read_at(0x{:X}, {}) failed: {}", + offset, + desc.len, + e + ); + return VIRTIO_BLK_S_IOERR; + } + if let Err(e) = mem.write_at(desc.addr, &buf) { + log::debug!( + "BLK READ: mem.write_at(0x{:X}, {}) failed: {}", + desc.addr, + 
buf.len(), + e + ); + return VIRTIO_BLK_S_IOERR; + } + offset += desc.len as u64; + } + VIRTIO_BLK_S_OK + } + + fn handle_write( + &mut self, + sector: u64, + data_descs: &[Descriptor], + mem: &dyn GuestMemoryAccessor, + ) -> u8 { + if self.read_only { + return VIRTIO_BLK_S_IOERR; + } + + let disk = match self.disk { + Some(ref mut d) => d, + None => return VIRTIO_BLK_S_IOERR, + }; + let mut offset = sector * SECTOR_SIZE; + + for desc in data_descs { + if desc.is_write() { + return VIRTIO_BLK_S_IOERR; // Data buffer must be device-readable for writes. + } + let mut buf = vec![0u8; desc.len as usize]; + if mem.read_at(desc.addr, &mut buf).is_err() { + return VIRTIO_BLK_S_IOERR; + } + if disk.write_at(offset, &buf).is_err() { + return VIRTIO_BLK_S_IOERR; + } + offset += desc.len as u64; + } + VIRTIO_BLK_S_OK + } + + fn handle_flush(&mut self) -> u8 { + let disk = match self.disk { + Some(ref mut d) => d, + None => return VIRTIO_BLK_S_IOERR, + }; + if disk.flush().is_err() { + VIRTIO_BLK_S_IOERR + } else { + VIRTIO_BLK_S_OK + } + } +} + +impl Drop for VirtioBlock { + fn drop(&mut self) { + self.stop_worker(); + } +} + +impl VirtioDeviceBackend for VirtioBlock { + fn device_id(&self) -> u32 { + VIRTIO_BLK_ID + } + + fn device_features(&self, page: u32) -> u32 { + match page { + 0 => { + let mut features = 0u32; + if self.read_only { + features |= 1 << VIRTIO_BLK_F_RO; + } + features + } + 1 => 1 << VIRTIO_F_VERSION_1, // VIRTIO_F_VERSION_1 is bit 32 (page 1, bit 0). + _ => 0, + } + } + + fn read_config(&self, offset: u64) -> u32 { + // Config space: capacity (u64 at offset 0). + match offset { + 0 => self.capacity as u32, // Low 32 bits. + 4 => (self.capacity >> 32) as u32, // High 32 bits. + _ => 0, + } + } + + fn queue_notify( + &mut self, + _queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + // Async mode: parse descriptors and dispatch to worker. 
+ if let Some(ref tx) = self.request_tx { + let mut dispatched = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = queue.add_used(head, 0, mem); + dispatched = true; + continue; + } + }; + + match Self::parse_request(&chain, head, mem) { + Some(req) => { + if tx.send(req).is_err() { + // Worker died — fall through with error status. + if let Some(status_desc) = chain.last() { + let _ = mem.write_at(status_desc.addr, &[VIRTIO_BLK_S_IOERR]); + } + let total_written: u32 = + chain.iter().filter(|d| d.is_write()).map(|d| d.len).sum(); + let _ = queue.add_used(head, total_written, mem); + dispatched = true; + } + } + None => { + // Malformed chain — write error status directly. + let _ = queue.add_used(head, 0, mem); + dispatched = true; + } + } + } + + // In async mode, don't raise interrupt here — completions + // arrive via drain_completions() during tick_and_poll(). + // Return dispatched for malformed chains that were handled inline. + dispatched + } else { + // Sync fallback: process requests directly (original behavior). 
+ let mut processed = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + }; + + let status = self.process_request(&chain, mem); + + if let Some(status_desc) = chain.last() { + let _ = mem.write_at(status_desc.addr, &[status]); + } + + let total_written: u32 = chain.iter().filter(|d| d.is_write()).map(|d| d.len).sum(); + let _ = queue.add_used(head, total_written, mem); + processed = true; + } + + processed + } + } + + fn drain_completions( + &mut self, + queues: &mut [Virtqueue], + mem: &dyn GuestMemoryAccessor, + ) -> bool { + if let Some(queue) = queues.first_mut() { + self.drain_completions(queue, mem) + } else { + false + } + } + + fn num_queues(&self) -> usize { + 1 // Virtio-blk uses a single request queue. + } + + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + 256 + } +} + +#[cfg(test)] +mod tests { + use super::super::super::error::WkrunError; + use super::disk::RawDiskBackend; + use super::*; + use std::cell::RefCell; + use std::fs::File; + use std::io::Write as IoWrite; + use tempfile::NamedTempFile; + + struct MockMem { + data: RefCell>, + } + + impl MockMem { + fn new(size: usize) -> Self { + MockMem { + data: RefCell::new(vec![0u8; size]), + } + } + + fn write_bytes(&self, addr: u64, bytes: &[u8]) { + let a = addr as usize; + let mut data = self.data.borrow_mut(); + data[a..a + bytes.len()].copy_from_slice(bytes); + } + + fn read_bytes(&self, addr: u64, len: usize) -> Vec { + let a = addr as usize; + let data = self.data.borrow(); + data[a..a + len].to_vec() + } + } + + impl GuestMemoryAccessor for MockMem { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> super::super::super::error::Result<()> { + let a = addr as usize; + let data = self.data.borrow(); + if a + buf.len() > data.len() { + return Err(WkrunError::Memory("out of bounds".into())); + } + 
buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + fn write_at(&self, addr: u64, data: &[u8]) -> super::super::super::error::Result<()> { + let a = addr as usize; + let mut mem = self.data.borrow_mut(); + if a + data.len() > mem.len() { + return Err(WkrunError::Memory("out of bounds".into())); + } + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + fn create_test_disk(sectors: u64) -> NamedTempFile { + let mut f = NamedTempFile::new().unwrap(); + let data = vec![0u8; (sectors * SECTOR_SIZE) as usize]; + f.write_all(&data).unwrap(); + f.flush().unwrap(); + f + } + + fn create_disk_with_pattern(sectors: u64) -> NamedTempFile { + let mut f = NamedTempFile::new().unwrap(); + for sector in 0..sectors { + let pattern = vec![(sector & 0xFF) as u8; SECTOR_SIZE as usize]; + f.write_all(&pattern).unwrap(); + } + f.flush().unwrap(); + f + } + + fn open_raw_backend(tmp: &NamedTempFile, read_only: bool) -> Box { + let file = File::options() + .read(true) + .write(!read_only) + .open(tmp.path()) + .unwrap(); + Box::new(RawDiskBackend::new(file).unwrap()) + } + + // --- Construction --- + + #[test] + fn test_new_block_device() { + let tmp = create_test_disk(8); + let backend = open_raw_backend(&tmp, false); + let blk = VirtioBlock::new(backend, false); + assert_eq!(blk.capacity(), 8); + assert_eq!(blk.device_id(), VIRTIO_BLK_ID); + assert!(!blk.has_worker()); + } + + #[test] + fn test_empty_disk_error() { + let tmp = NamedTempFile::new().unwrap(); + let file = File::open(tmp.path()).unwrap(); + assert!(RawDiskBackend::new(file).is_err()); + } + + #[test] + fn test_read_only_features() { + let tmp = create_test_disk(1); + let backend = open_raw_backend(&tmp, true); + let blk = VirtioBlock::new(backend, true); + let features = blk.device_features(0); + assert_ne!(features & (1 << VIRTIO_BLK_F_RO), 0); + } + + // --- Config space --- + + #[test] + fn test_config_capacity() { + let tmp = create_test_disk(1024); + let backend = open_raw_backend(&tmp, false); 
+ let blk = VirtioBlock::new(backend, false); + assert_eq!(blk.read_config(0), 1024); // Low. + assert_eq!(blk.read_config(4), 0); // High. + } + + // --- Request processing (direct/sync) --- + + #[test] + fn test_read_request() { + let tmp = create_disk_with_pattern(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, false); + let mem = MockMem::new(0x10000); + + // Write request header: type=IN, sector=2. + let mut header = [0u8; 16]; + header[0..4].copy_from_slice(&VIRTIO_BLK_T_IN.to_le_bytes()); + header[8..16].copy_from_slice(&2u64.to_le_bytes()); + mem.write_bytes(0x1000, &header); + + // Build descriptor chain. + let chain = vec![ + Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }, // Header (device-readable). + Descriptor { + addr: 0x2000, + len: 512, + flags: 2, + next: 0, + }, // Data (device-writable). + Descriptor { + addr: 0x3000, + len: 1, + flags: 2, + next: 0, + }, // Status (device-writable). + ]; + + let status = blk.process_request(&chain, &mem); + assert_eq!(status, VIRTIO_BLK_S_OK); + + // Check that data was read (sector 2 pattern = 0x02). + let data = mem.read_bytes(0x2000, 512); + assert!(data.iter().all(|&b| b == 0x02)); + } + + #[test] + fn test_write_request() { + let tmp = create_test_disk(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, false); + let mem = MockMem::new(0x10000); + + // Header: type=OUT, sector=1. + let mut header = [0u8; 16]; + header[0..4].copy_from_slice(&VIRTIO_BLK_T_OUT.to_le_bytes()); + header[8..16].copy_from_slice(&1u64.to_le_bytes()); + mem.write_bytes(0x1000, &header); + + // Data to write (device-readable). + let write_data = vec![0xABu8; 512]; + mem.write_bytes(0x2000, &write_data); + + let chain = vec![ + Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x2000, + len: 512, + flags: 0, + next: 0, + }, // Device-readable. 
+ Descriptor { + addr: 0x3000, + len: 1, + flags: 2, + next: 0, + }, // Status. + ]; + + let status = blk.process_request(&chain, &mem); + assert_eq!(status, VIRTIO_BLK_S_OK); + + // Verify by reading back. + let mut header2 = [0u8; 16]; + header2[0..4].copy_from_slice(&VIRTIO_BLK_T_IN.to_le_bytes()); + header2[8..16].copy_from_slice(&1u64.to_le_bytes()); + mem.write_bytes(0x4000, &header2); + + let read_chain = vec![ + Descriptor { + addr: 0x4000, + len: 16, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x5000, + len: 512, + flags: 2, + next: 0, + }, + Descriptor { + addr: 0x6000, + len: 1, + flags: 2, + next: 0, + }, + ]; + + let status2 = blk.process_request(&read_chain, &mem); + assert_eq!(status2, VIRTIO_BLK_S_OK); + let readback = mem.read_bytes(0x5000, 512); + assert!(readback.iter().all(|&b| b == 0xAB)); + } + + #[test] + fn test_write_rejected_on_read_only() { + let tmp = create_test_disk(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, true); + let mem = MockMem::new(0x10000); + + let mut header = [0u8; 16]; + header[0..4].copy_from_slice(&VIRTIO_BLK_T_OUT.to_le_bytes()); + mem.write_bytes(0x1000, &header); + + let chain = vec![ + Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x2000, + len: 512, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x3000, + len: 1, + flags: 2, + next: 0, + }, + ]; + + let status = blk.process_request(&chain, &mem); + assert_eq!(status, VIRTIO_BLK_S_IOERR); + } + + #[test] + fn test_flush_request() { + let tmp = create_test_disk(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, false); + let mem = MockMem::new(0x10000); + + let mut header = [0u8; 16]; + header[0..4].copy_from_slice(&VIRTIO_BLK_T_FLUSH.to_le_bytes()); + mem.write_bytes(0x1000, &header); + + let chain = vec![ + Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x3000, + len: 1, + 
flags: 2, + next: 0, + }, + ]; + + let status = blk.process_request(&chain, &mem); + assert_eq!(status, VIRTIO_BLK_S_OK); + } + + #[test] + fn test_unsupported_request_type() { + let tmp = create_test_disk(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, false); + let mem = MockMem::new(0x10000); + + let mut header = [0u8; 16]; + header[0..4].copy_from_slice(&99u32.to_le_bytes()); // Unknown type. + mem.write_bytes(0x1000, &header); + + let chain = vec![ + Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x2000, + len: 512, + flags: 2, + next: 0, + }, + Descriptor { + addr: 0x3000, + len: 1, + flags: 2, + next: 0, + }, + ]; + + let status = blk.process_request(&chain, &mem); + assert_eq!(status, VIRTIO_BLK_S_UNSUPP); + } + + #[test] + fn test_short_chain_error() { + let tmp = create_test_disk(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, false); + let mem = MockMem::new(0x10000); + + let chain = vec![Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }]; + + let status = blk.process_request(&chain, &mem); + assert_eq!(status, VIRTIO_BLK_S_IOERR); + } + + // --- VirtioDeviceBackend trait --- + + #[test] + fn test_version_1_feature() { + let tmp = create_test_disk(1); + let backend = open_raw_backend(&tmp, false); + let blk = VirtioBlock::new(backend, false); + let features_page1 = blk.device_features(1); + assert_eq!(features_page1, 1); // Bit 0 of page 1 = VIRTIO_F_VERSION_1. 
+ } + + #[test] + fn test_num_queues() { + let tmp = create_test_disk(1); + let backend = open_raw_backend(&tmp, false); + let blk = VirtioBlock::new(backend, false); + assert_eq!(blk.num_queues(), 1); + } + + #[test] + fn test_queue_max_size() { + let tmp = create_test_disk(1); + let backend = open_raw_backend(&tmp, false); + let blk = VirtioBlock::new(backend, false); + assert_eq!(blk.queue_max_size(0), 256); + } + + // --- parse_request --- + + #[test] + fn test_parse_request_read() { + let mem = MockMem::new(0x10000); + + let mut header = [0u8; 16]; + header[0..4].copy_from_slice(&VIRTIO_BLK_T_IN.to_le_bytes()); + header[8..16].copy_from_slice(&5u64.to_le_bytes()); + mem.write_bytes(0x1000, &header); + + let chain = vec![ + Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }, + Descriptor { + addr: 0x2000, + len: 512, + flags: 2, + next: 0, + }, + Descriptor { + addr: 0x3000, + len: 1, + flags: 2, + next: 0, + }, + ]; + + let req = VirtioBlock::parse_request(&chain, 10, &mem).unwrap(); + assert_eq!(req.head_index, 10); + assert_eq!(req.req_type, RequestType::Read); + assert_eq!(req.sector, 5); + assert_eq!(req.data_buffers.len(), 1); + assert!(req.data_buffers[0].is_write); + assert_eq!(req.status_addr, 0x3000); + } + + #[test] + fn test_parse_request_short_chain_returns_none() { + let mem = MockMem::new(0x10000); + let chain = vec![Descriptor { + addr: 0x1000, + len: 16, + flags: 0, + next: 0, + }]; + assert!(VirtioBlock::parse_request(&chain, 0, &mem).is_none()); + } + + // --- stop_worker --- + + #[test] + fn test_stop_worker_without_start_is_noop() { + let tmp = create_test_disk(4); + let backend = open_raw_backend(&tmp, false); + let mut blk = VirtioBlock::new(backend, false); + blk.stop_worker(); // Should not panic. 
+ } +} diff --git a/src/vmm/src/windows/devices/virtio/block_worker.rs b/src/vmm/src/windows/devices/virtio/block_worker.rs new file mode 100644 index 000000000..30f46e536 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/block_worker.rs @@ -0,0 +1,667 @@ +//! Async block I/O worker thread for virtio-blk. +//! +//! Moves disk I/O off the vCPU loop into a dedicated thread so that +//! long-running reads/writes don't starve vsock or net devices. +//! +//! **Plan B (WHPX-safe)**: The worker thread NEVER accesses guest memory. +//! - For reads: worker reads disk → Vec, sends Vec in completion. +//! The vCPU thread writes the data to guest memory. +//! - For writes: vCPU thread pre-reads data from guest memory → Vec, +//! sends Vec in request. Worker writes Vec to disk. +//! - Status byte is always written by the vCPU thread. +//! +//! This avoids WHPX memory coherence issues where non-vCPU thread +//! writes to guest memory cause ~60% boot failure on Win10. + +use std::sync::mpsc; +use std::thread; + +use super::disk::DiskBackend; + +/// Block size in bytes (standard sector size). +const SECTOR_SIZE: u64 = 512; + +// Virtio-blk status values. +const VIRTIO_BLK_S_OK: u8 = 0; +const VIRTIO_BLK_S_IOERR: u8 = 1; +const VIRTIO_BLK_S_UNSUPP: u8 = 2; + +/// Type of block request. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RequestType { + Read, + Write, + Flush, + Unsupported, +} + +/// A single buffer descriptor from the virtqueue chain. +/// +/// Used in completions to tell the vCPU thread where to write read data. +#[derive(Debug, Clone)] +pub struct BufferDesc { + /// Guest physical address. + pub addr: u64, + /// Length in bytes. + pub len: u32, + /// Whether this buffer is device-writable (guest reads from it). + pub is_write: bool, +} + +/// A block request dispatched from the vCPU thread to the worker. +#[derive(Debug)] +pub struct BlockRequest { + /// Descriptor chain head index (for add_used later). 
+ pub head_index: u16, + /// Request type (read/write/flush). + pub req_type: RequestType, + /// Starting sector (for read/write). + pub sector: u64, + /// Data buffer descriptors (between header and status). + /// For reads: describes where vCPU thread should write the returned data. + /// For writes: only used for metadata (the actual data is in write_data). + pub data_buffers: Vec, + /// Guest address of the status byte (last descriptor). + pub status_addr: u64, + /// Pre-read write data from guest memory (only for Write requests). + /// The vCPU thread reads this from guest memory before sending. + pub write_data: Option>, +} + +/// Completion sent from the worker back to the vCPU thread. +#[derive(Debug)] +pub struct BlockCompletion { + /// Descriptor chain head index. + pub head_index: u16, + /// Total bytes written to device-writable descriptors (for used ring). + pub bytes_written: u32, + /// Virtio-blk status byte (OK/IOERR/UNSUPP). + pub status: u8, + /// Guest address where status byte should be written. + pub status_addr: u64, + /// Data read from disk (only for Read requests). + /// The vCPU thread writes this to guest memory at the addresses + /// specified in read_targets. + pub read_data: Option>, + /// Guest memory targets for read data (addr, len pairs from data_buffers). + /// The vCPU thread iterates these to scatter read_data into guest memory. + pub read_targets: Vec, +} + +/// Worker thread that processes block I/O requests. +/// +/// The worker NEVER accesses guest memory. All guest memory reads/writes +/// are done by the vCPU thread (Plan B for WHPX safety). +pub struct BlockWorker { + request_rx: mpsc::Receiver, + completion_tx: mpsc::Sender, + disk: Box, + read_only: bool, +} + +impl BlockWorker { + /// Create a new block worker. 
+    ///
+    /// `request_rx` delivers requests to the worker and `completion_tx`
+    /// carries one completion back per request. `disk` is the backend the
+    /// worker performs I/O on; when `read_only` is set every write request
+    /// is answered with an I/O error without touching the backend.
+    pub fn new(
+        request_rx: mpsc::Receiver<BlockRequest>,
+        completion_tx: mpsc::Sender<BlockCompletion>,
+        disk: Box<dyn DiskBackend>,
+        read_only: bool,
+    ) -> Self {
+        BlockWorker {
+            request_rx,
+            completion_tx,
+            disk,
+            read_only,
+        }
+    }
+
+    /// Spawn the worker on a named thread. Returns the join handle.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the operating system refuses to create the thread.
+    pub fn run(self, name: &str) -> thread::JoinHandle<()> {
+        thread::Builder::new()
+            .name(name.to_owned())
+            .spawn(move || self.work())
+            .expect("failed to spawn block worker thread")
+    }
+
+    /// Blocking recv loop: process requests until the channel closes.
+    ///
+    /// The loop ends when every request sender has been dropped, or as soon
+    /// as a completion cannot be delivered because the receiver is gone.
+    fn work(mut self) {
+        log::info!("block worker started (Plan B: no guest memory access)");
+
+        while let Ok(req) = self.request_rx.recv() {
+            let completion = self.process_request(req);
+
+            // If the vCPU thread dropped its receiver, the VM is shutting down.
+            if self.completion_tx.send(completion).is_err() {
+                break;
+            }
+        }
+
+        log::info!("block worker exiting");
+    }
+
+    /// Process a single block request. Returns a completion with data/status.
+ fn process_request(&mut self, req: BlockRequest) -> BlockCompletion { + match req.req_type { + RequestType::Read => self.handle_read(req), + RequestType::Write => self.handle_write(req), + RequestType::Flush => self.handle_flush(req), + RequestType::Unsupported => BlockCompletion { + head_index: req.head_index, + bytes_written: 0, + status: VIRTIO_BLK_S_UNSUPP, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }, + } + } + + fn handle_read(&mut self, req: BlockRequest) -> BlockCompletion { + let mut offset = req.sector * SECTOR_SIZE; + let mut all_data = Vec::new(); + let mut bytes_written: u32 = 0; + + for buf in &req.data_buffers { + if !buf.is_write { + log::debug!("BLK worker READ: buffer not device-writable"); + return BlockCompletion { + head_index: req.head_index, + bytes_written, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + let mut data = vec![0u8; buf.len as usize]; + if let Err(e) = self.disk.read_at(offset, &mut data) { + log::debug!("BLK worker READ: disk.read_at failed: {}", e); + return BlockCompletion { + head_index: req.head_index, + bytes_written, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + all_data.extend_from_slice(&data); + offset += buf.len as u64; + bytes_written += buf.len; + } + + // +1 for the status byte (also device-writable). 
+ BlockCompletion { + head_index: req.head_index, + bytes_written: bytes_written + 1, + status: VIRTIO_BLK_S_OK, + status_addr: req.status_addr, + read_data: Some(all_data), + read_targets: req.data_buffers, + } + } + + fn handle_write(&mut self, req: BlockRequest) -> BlockCompletion { + if self.read_only { + return BlockCompletion { + head_index: req.head_index, + bytes_written: 0, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + + let write_data = match req.write_data { + Some(ref data) => data, + None => { + log::debug!("BLK worker WRITE: no write_data provided"); + return BlockCompletion { + head_index: req.head_index, + bytes_written: 0, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + }; + + let mut offset = req.sector * SECTOR_SIZE; + let mut data_offset: usize = 0; + + for buf in &req.data_buffers { + if buf.is_write { + // Data for write must be device-readable (not device-writable). + return BlockCompletion { + head_index: req.head_index, + bytes_written: 0, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + let end = data_offset + buf.len as usize; + if end > write_data.len() { + log::debug!("BLK worker WRITE: write_data too short"); + return BlockCompletion { + head_index: req.head_index, + bytes_written: 0, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + if self + .disk + .write_at(offset, &write_data[data_offset..end]) + .is_err() + { + return BlockCompletion { + head_index: req.head_index, + bytes_written: 0, + status: VIRTIO_BLK_S_IOERR, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + }; + } + offset += buf.len as u64; + data_offset = end; + } + + // Only status byte is device-writable for writes. 
+ BlockCompletion { + head_index: req.head_index, + bytes_written: 1, + status: VIRTIO_BLK_S_OK, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + } + } + + fn handle_flush(&mut self, req: BlockRequest) -> BlockCompletion { + let status = if self.disk.flush().is_err() { + VIRTIO_BLK_S_IOERR + } else { + VIRTIO_BLK_S_OK + }; + // bytes_written=1 for the status byte (device-writable), + // matching the sync path in VirtioBlock::queue_notify. + BlockCompletion { + head_index: req.head_index, + bytes_written: 1, + status, + status_addr: req.status_addr, + read_data: None, + read_targets: vec![], + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::mpsc; + + /// In-memory disk backend for testing. + struct MemDisk { + data: Vec, + read_only: bool, + } + + impl MemDisk { + fn new(size: usize) -> Self { + MemDisk { + data: vec![0u8; size], + read_only: false, + } + } + + fn with_pattern(sectors: u64) -> Self { + let size = (sectors * SECTOR_SIZE) as usize; + let mut data = vec![0u8; size]; + for sector in 0..sectors { + let start = (sector * SECTOR_SIZE) as usize; + let end = start + SECTOR_SIZE as usize; + data[start..end].fill((sector & 0xFF) as u8); + } + MemDisk { + data, + read_only: false, + } + } + } + + // Safety: MemDisk only uses Vec which is Send. 
+ unsafe impl Send for MemDisk {} + + impl DiskBackend for MemDisk { + fn read_at( + &mut self, + offset: u64, + buf: &mut [u8], + ) -> super::super::super::super::error::Result<()> { + let start = offset as usize; + let end = start + buf.len(); + if end > self.data.len() { + return Err(super::super::super::super::error::WkrunError::Device( + "read out of bounds".into(), + )); + } + buf.copy_from_slice(&self.data[start..end]); + Ok(()) + } + + fn write_at( + &mut self, + offset: u64, + buf: &[u8], + ) -> super::super::super::super::error::Result<()> { + if self.read_only { + return Err(super::super::super::super::error::WkrunError::Device( + "read-only disk".into(), + )); + } + let start = offset as usize; + let end = start + buf.len(); + if end > self.data.len() { + return Err(super::super::super::super::error::WkrunError::Device( + "write out of bounds".into(), + )); + } + self.data[start..end].copy_from_slice(buf); + Ok(()) + } + + fn flush(&mut self) -> super::super::super::super::error::Result<()> { + Ok(()) + } + + fn capacity_bytes(&self) -> u64 { + self.data.len() as u64 + } + } + + #[test] + fn test_worker_read_request() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::with_pattern(4); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-read"); + + // Send a read request for sector 2 (pattern = 0x02). + req_tx + .send(BlockRequest { + head_index: 42, + req_type: RequestType::Read, + sector: 2, + data_buffers: vec![BufferDesc { + addr: 0x2000, + len: 512, + is_write: true, + }], + status_addr: 0x3000, + write_data: None, + }) + .unwrap(); + + // Close the channel to let the worker exit. + drop(req_tx); + handle.join().unwrap(); + + // Check completion. 
+ let comp = comp_rx.recv().unwrap(); + assert_eq!(comp.head_index, 42); + assert_eq!(comp.bytes_written, 513); // 512 data + 1 status + assert_eq!(comp.status, VIRTIO_BLK_S_OK); + assert_eq!(comp.status_addr, 0x3000); + + // Verify read data is returned in completion (not written to guest mem). + let read_data = comp.read_data.unwrap(); + assert_eq!(read_data.len(), 512); + assert!(read_data.iter().all(|&b| b == 0x02)); + + // Verify read targets match the original buffers. + assert_eq!(comp.read_targets.len(), 1); + assert_eq!(comp.read_targets[0].addr, 0x2000); + assert_eq!(comp.read_targets[0].len, 512); + } + + #[test] + fn test_worker_write_request() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::new(2048); + + // Write data is pre-read from guest memory by the vCPU thread. + let write_data = vec![0xAB; 512]; + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-write"); + + req_tx + .send(BlockRequest { + head_index: 7, + req_type: RequestType::Write, + sector: 1, + data_buffers: vec![BufferDesc { + addr: 0x2000, + len: 512, + is_write: false, // Device-readable for writes. + }], + status_addr: 0x3000, + write_data: Some(write_data), + }) + .unwrap(); + + drop(req_tx); + handle.join().unwrap(); + + let comp = comp_rx.recv().unwrap(); + assert_eq!(comp.head_index, 7); + assert_eq!(comp.bytes_written, 1); // Only status byte is writable. 
+ assert_eq!(comp.status, VIRTIO_BLK_S_OK); + assert!(comp.read_data.is_none()); + } + + #[test] + fn test_worker_flush_request() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::new(1024); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-flush"); + + req_tx + .send(BlockRequest { + head_index: 3, + req_type: RequestType::Flush, + sector: 0, + data_buffers: vec![], + status_addr: 0x3000, + write_data: None, + }) + .unwrap(); + + drop(req_tx); + handle.join().unwrap(); + + let comp = comp_rx.recv().unwrap(); + assert_eq!(comp.head_index, 3); + assert_eq!(comp.bytes_written, 1); // status byte + assert_eq!(comp.status, VIRTIO_BLK_S_OK); + } + + #[test] + fn test_worker_unsupported_request() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::new(1024); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-unsupp"); + + req_tx + .send(BlockRequest { + head_index: 5, + req_type: RequestType::Unsupported, + sector: 0, + data_buffers: vec![], + status_addr: 0x3000, + write_data: None, + }) + .unwrap(); + + drop(req_tx); + handle.join().unwrap(); + + let comp = comp_rx.recv().unwrap(); + assert_eq!(comp.head_index, 5); + assert_eq!(comp.bytes_written, 0); + assert_eq!(comp.status, VIRTIO_BLK_S_UNSUPP); + } + + #[test] + fn test_worker_multiple_requests() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::with_pattern(8); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-multi"); + + // Send 3 read requests for sectors 0, 1, 2. 
+ for i in 0..3u16 { + req_tx + .send(BlockRequest { + head_index: i, + req_type: RequestType::Read, + sector: i as u64, + data_buffers: vec![BufferDesc { + addr: 0x2000 + (i as u64) * 0x1000, + len: 512, + is_write: true, + }], + status_addr: 0x8000 + i as u64, + write_data: None, + }) + .unwrap(); + } + + drop(req_tx); + handle.join().unwrap(); + + // All 3 completions should arrive. + let mut completions: Vec = Vec::new(); + while let Ok(c) = comp_rx.try_recv() { + completions.push(c); + } + assert_eq!(completions.len(), 3); + + // Verify each sector's data is in the completion. + for (idx, comp) in completions.iter().enumerate() { + let data = comp.read_data.as_ref().unwrap(); + assert_eq!(data.len(), 512); + assert!( + data.iter().all(|&b| b == idx as u8), + "sector {} data mismatch", + idx + ); + } + } + + #[test] + fn test_worker_read_only_rejects_write() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::new(1024); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), true); + let handle = worker.run("test-blk-ro"); + + req_tx + .send(BlockRequest { + head_index: 1, + req_type: RequestType::Write, + sector: 0, + data_buffers: vec![BufferDesc { + addr: 0x2000, + len: 512, + is_write: false, + }], + status_addr: 0x3000, + write_data: Some(vec![0xAB; 512]), + }) + .unwrap(); + + drop(req_tx); + handle.join().unwrap(); + + let comp = comp_rx.recv().unwrap(); + assert_eq!(comp.bytes_written, 0); + assert_eq!(comp.status, VIRTIO_BLK_S_IOERR); + } + + #[test] + fn test_worker_graceful_shutdown_on_channel_close() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, _comp_rx) = mpsc::channel(); + + let disk = MemDisk::new(1024); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-shutdown"); + + // Drop the sender — worker should exit gracefully. + drop(req_tx); + handle.join().unwrap(); // Should not hang or panic. 
+ } + + #[test] + fn test_worker_write_missing_data_returns_error() { + let (req_tx, req_rx) = mpsc::channel(); + let (comp_tx, comp_rx) = mpsc::channel(); + + let disk = MemDisk::new(2048); + + let worker = BlockWorker::new(req_rx, comp_tx, Box::new(disk), false); + let handle = worker.run("test-blk-write-nodata"); + + // Write request without write_data should fail. + req_tx + .send(BlockRequest { + head_index: 1, + req_type: RequestType::Write, + sector: 0, + data_buffers: vec![BufferDesc { + addr: 0x2000, + len: 512, + is_write: false, + }], + status_addr: 0x3000, + write_data: None, // Missing! + }) + .unwrap(); + + drop(req_tx); + handle.join().unwrap(); + + let comp = comp_rx.recv().unwrap(); + assert_eq!(comp.status, VIRTIO_BLK_S_IOERR); + } +} diff --git a/src/vmm/src/windows/devices/virtio/disk.rs b/src/vmm/src/windows/devices/virtio/disk.rs new file mode 100644 index 000000000..f33a6f62a --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/disk.rs @@ -0,0 +1,1298 @@ +//! Disk backend abstraction for virtio-blk. +//! +//! Provides a `DiskBackend` trait to abstract block device I/O, +//! with implementations for raw files and qcow2 images. + +use std::fs::File; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; + +use super::super::super::error::{Result, WkrunError}; +/// Disk format: raw file passthrough. +pub const DISK_FORMAT_RAW: u32 = 0; +/// Disk format: qcow2 image. +pub const DISK_FORMAT_QCOW2: u32 = 1; + +/// Abstract block device I/O. +/// +/// Backends translate guest sector reads/writes to the underlying +/// storage format (raw file, qcow2 image, etc.). +pub trait DiskBackend: Send { + /// Read `buf.len()` bytes starting at `offset` into `buf`. + fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> Result<()>; + + /// Write `buf` starting at `offset`. + fn write_at(&mut self, offset: u64, buf: &[u8]) -> Result<()>; + + /// Flush pending writes to stable storage. 
+ fn flush(&mut self) -> Result<()>; + + /// Virtual disk size in bytes. + fn capacity_bytes(&self) -> u64; +} + +// --------------------------------------------------------------------------- +// Raw disk backend +// --------------------------------------------------------------------------- + +/// Raw file-backed disk — direct passthrough to the host file. +pub struct RawDiskBackend { + file: File, + capacity: u64, +} + +impl RawDiskBackend { + /// Wrap an open file as a raw disk backend. + /// + /// The file size must be > 0 (i.e., the file must not be empty). + pub fn new(file: File) -> Result { + let metadata = file + .metadata() + .map_err(|e| WkrunError::Device(format!("failed to get disk metadata: {}", e)))?; + let capacity = metadata.len(); + if capacity == 0 { + return Err(WkrunError::Device("disk file is empty".into())); + } + Ok(RawDiskBackend { file, capacity }) + } +} + +impl DiskBackend for RawDiskBackend { + fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> Result<()> { + self.file + .seek(SeekFrom::Start(offset)) + .map_err(|e| WkrunError::Device(format!("disk seek failed: {}", e)))?; + self.file + .read_exact(buf) + .map_err(|e| WkrunError::Device(format!("disk read failed: {}", e)))?; + Ok(()) + } + + fn write_at(&mut self, offset: u64, buf: &[u8]) -> Result<()> { + self.file + .seek(SeekFrom::Start(offset)) + .map_err(|e| WkrunError::Device(format!("disk seek failed: {}", e)))?; + self.file + .write_all(buf) + .map_err(|e| WkrunError::Device(format!("disk write failed: {}", e)))?; + Ok(()) + } + + fn flush(&mut self) -> Result<()> { + self.file + .sync_all() + .map_err(|e| WkrunError::Device(format!("disk flush failed: {}", e)))?; + Ok(()) + } + + fn capacity_bytes(&self) -> u64 { + self.capacity + } +} + +// --------------------------------------------------------------------------- +// qcow2 disk backend +// --------------------------------------------------------------------------- + +/// qcow2 magic number: 'Q', 'F', 'I', 0xFB. 
+const QCOW2_MAGIC: u32 = 0x514649FB;
+
+/// Mask to extract the cluster-aligned file offset from an L1 or L2 entry.
+/// Bits 55:9 — zeroes out the top flag bits and the sub-cluster offset.
+const L2_OFFSET_MASK: u64 = 0x00FF_FFFF_FFFF_FE00;
+
+/// Parsed qcow2 header (fields common to v2 and v3).
+#[derive(Debug)]
+struct Qcow2Header {
+    #[allow(dead_code)]
+    version: u32,
+    backing_file_offset: u64, // File offset of the backing file path (0 = none).
+    backing_file_size: u32,   // Length of the backing file path in bytes.
+    cluster_bits: u32,
+    size: u64,            // Virtual disk size in bytes.
+    l1_size: u32,         // Number of entries in the L1 table.
+    l1_table_offset: u64, // File offset of the L1 table.
+    refcount_table_offset: u64,
+    refcount_table_clusters: u32,
+    refcount_order: u32, // log2(refcount bits); 4 means 16-bit refcounts.
+}
+
+impl Qcow2Header {
+    /// Parse a qcow2 header from the first 104 bytes of the file.
+    ///
+    /// # Errors
+    ///
+    /// Rejects images with a bad magic, a version other than 2 or 3,
+    /// `cluster_bits` outside 9..=21, encryption, snapshots, a v3
+    /// incompatible-feature bit other than "dirty", or a `refcount_order`
+    /// above 6.
+    fn parse(buf: &[u8; 104]) -> Result<Self> {
+        // All header fields are big-endian; these read one at a byte offset.
+        let be32 = |at: usize| u32::from_be_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]]);
+        let be64 = |at: usize| {
+            let mut raw = [0u8; 8];
+            raw.copy_from_slice(&buf[at..at + 8]);
+            u64::from_be_bytes(raw)
+        };
+
+        let magic = be32(0);
+        if magic != QCOW2_MAGIC {
+            return Err(WkrunError::Device(format!(
+                "not a qcow2 image: bad magic 0x{:08X}",
+                magic
+            )));
+        }
+
+        let version = be32(4);
+        if version != 2 && version != 3 {
+            return Err(WkrunError::Device(format!(
+                "unsupported qcow2 version: {}",
+                version
+            )));
+        }
+
+        // Backing file offset and size (parsed but not validated here —
+        // the backend's open() method handles backing file resolution).
+        let backing_file_offset = be64(8);
+        let backing_file_size = be32(16);
+
+        let cluster_bits = be32(20);
+        if !(9..=21).contains(&cluster_bits) {
+            return Err(WkrunError::Device(format!(
+                "invalid qcow2 cluster_bits: {}",
+                cluster_bits
+            )));
+        }
+
+        let size = be64(24);
+        let crypt_method = be32(32);
+        if crypt_method != 0 {
+            return Err(WkrunError::Device(
+                "qcow2 encryption is not supported".into(),
+            ));
+        }
+
+        let l1_size = be32(36);
+        let l1_table_offset = be64(40);
+        let refcount_table_offset = be64(48);
+        let refcount_table_clusters = be32(56);
+
+        let nb_snapshots = be32(60);
+        if nb_snapshots != 0 {
+            return Err(WkrunError::Device(
+                "qcow2 snapshots are not supported".into(),
+            ));
+        }
+
+        // v3 has refcount_order at offset 96; v2 defaults to 4 (16-bit).
+        let refcount_order = if version >= 3 {
+            // Incompatible-feature bits (offset 72) describe layouts a reader
+            // must understand to interpret the image at all. Only the "dirty"
+            // bit (bit 0) is tolerated; it affects refcount accuracy only.
+            const INCOMPAT_DIRTY: u64 = 1;
+            let unsupported = be64(72) & !INCOMPAT_DIRTY;
+            if unsupported != 0 {
+                return Err(WkrunError::Device(format!(
+                    "qcow2 incompatible features not supported: 0x{:X}",
+                    unsupported
+                )));
+            }
+            be32(96)
+        } else {
+            4
+        };
+        // The format caps refcounts at 64 bits (order 6); larger values would
+        // also overflow the shifts that derive the refcount width.
+        if refcount_order > 6 {
+            return Err(WkrunError::Device(format!(
+                "invalid qcow2 refcount_order: {}",
+                refcount_order
+            )));
+        }
+
+        Ok(Qcow2Header {
+            version,
+            backing_file_offset,
+            backing_file_size,
+            cluster_bits,
+            size,
+            l1_size,
+            l1_table_offset,
+            refcount_table_offset,
+            refcount_table_clusters,
+            refcount_order,
+        })
+    }
+}
+
+/// Detect disk format by checking for QCOW2 magic bytes.
+fn detect_disk_format(path: &Path) -> Result { + let mut f = File::open(path).map_err(|e| { + WkrunError::Device(format!( + "failed to open '{}' for format detection: {}", + path.display(), + e + )) + })?; + let mut magic = [0u8; 4]; + f.read_exact(&mut magic).map_err(|e| { + WkrunError::Device(format!( + "failed to read magic from '{}': {}", + path.display(), + e + )) + })?; + if u32::from_be_bytes(magic) == QCOW2_MAGIC { + Ok(DISK_FORMAT_QCOW2) + } else { + Ok(DISK_FORMAT_RAW) + } +} + +/// qcow2 image backend with two-level L1/L2 table navigation. +/// +/// Supports reading and writing existing qcow2 images. New clusters +/// are allocated by appending to the end of the file (append-only). +/// Unallocated clusters delegate to an optional backing file. +/// No compression, encryption, or snapshot support. +struct Qcow2DiskBackend { + file: File, + header: Qcow2Header, + cluster_size: u64, + l2_entries_per_table: u64, + l1_table: Vec, + refcount_table: Vec, + next_free_cluster: u64, + read_only: bool, + backing: Option>, +} + +impl Qcow2DiskBackend { + /// Open a qcow2 image file and parse its metadata. + fn open(path: &Path, read_only: bool) -> Result { + let mut file = File::options() + .read(true) + .write(!read_only) + .open(path) + .map_err(|e| { + WkrunError::Device(format!( + "failed to open qcow2 disk '{}': {}", + path.display(), + e + )) + })?; + + // Read header. + let mut header_buf = [0u8; 104]; + file.read_exact(&mut header_buf) + .map_err(|e| WkrunError::Device(format!("failed to read qcow2 header: {}", e)))?; + let header = Qcow2Header::parse(&header_buf)?; + + let cluster_size = 1u64 << header.cluster_bits; + let l2_entries_per_table = cluster_size / 8; + + // Read L1 table. 
+ let l1_byte_len = (header.l1_size as usize) * 8; + let mut l1_bytes = vec![0u8; l1_byte_len]; + file.seek(SeekFrom::Start(header.l1_table_offset)) + .map_err(|e| WkrunError::Device(format!("failed to seek to L1 table: {}", e)))?; + file.read_exact(&mut l1_bytes) + .map_err(|e| WkrunError::Device(format!("failed to read L1 table: {}", e)))?; + let l1_table: Vec = l1_bytes + .chunks_exact(8) + .map(|c| u64::from_be_bytes(c.try_into().unwrap())) + .collect(); + + // Read refcount table. + let refcount_entries = (header.refcount_table_clusters as u64 * cluster_size / 8) as usize; + let mut refcount_bytes = vec![0u8; refcount_entries * 8]; + file.seek(SeekFrom::Start(header.refcount_table_offset)) + .map_err(|e| WkrunError::Device(format!("failed to seek to refcount table: {}", e)))?; + file.read_exact(&mut refcount_bytes) + .map_err(|e| WkrunError::Device(format!("failed to read refcount table: {}", e)))?; + let refcount_table: Vec = refcount_bytes + .chunks_exact(8) + .map(|c| u64::from_be_bytes(c.try_into().unwrap())) + .collect(); + + // Determine next free cluster: end of file rounded up to cluster boundary. + let file_len = file + .seek(SeekFrom::End(0)) + .map_err(|e| WkrunError::Device(format!("failed to get qcow2 file size: {}", e)))?; + let next_free_cluster = file_len.div_ceil(cluster_size) * cluster_size; + + // Open backing file if referenced in the header. 
+ let backing = if header.backing_file_offset != 0 && header.backing_file_size > 0 { + file.seek(SeekFrom::Start(header.backing_file_offset)) + .map_err(|e| { + WkrunError::Device(format!("failed to seek to backing file path: {}", e)) + })?; + let mut path_buf = vec![0u8; header.backing_file_size as usize]; + file.read_exact(&mut path_buf).map_err(|e| { + WkrunError::Device(format!("failed to read backing file path: {}", e)) + })?; + let backing_path_str = String::from_utf8(path_buf).map_err(|e| { + WkrunError::Device(format!("invalid UTF-8 in backing file path: {}", e)) + })?; + + // Resolve relative paths against the parent directory of this qcow2 file. + let backing_path = { + let p = PathBuf::from(&backing_path_str); + if p.is_absolute() { + p + } else { + path.parent().unwrap_or_else(|| Path::new(".")).join(&p) + } + }; + + let backing_format = detect_disk_format(&backing_path)?; + let backend = open_disk_backend(&backing_path, backing_format, true)?; + Some(backend) + } else { + None + }; + + Ok(Qcow2DiskBackend { + file, + header, + cluster_size, + l2_entries_per_table, + l1_table, + refcount_table, + next_free_cluster, + read_only, + backing, + }) + } + + /// Resolve a guest byte offset to a host file offset. + /// Returns `None` if the cluster is unallocated. + fn resolve_offset(&mut self, guest_offset: u64) -> Result> { + let l1_index = (guest_offset / self.cluster_size / self.l2_entries_per_table) as usize; + let l2_index = ((guest_offset / self.cluster_size) % self.l2_entries_per_table) as usize; + let offset_in_cluster = guest_offset % self.cluster_size; + + if l1_index >= self.l1_table.len() { + return Ok(None); + } + + let l1_entry = self.l1_table[l1_index]; + let l2_table_offset = l1_entry & L2_OFFSET_MASK; + if l2_table_offset == 0 { + return Ok(None); + } + + // Read the L2 entry. 
+ let l2_entry_file_offset = l2_table_offset + (l2_index as u64) * 8; + self.file + .seek(SeekFrom::Start(l2_entry_file_offset)) + .map_err(|e| WkrunError::Device(format!("qcow2: failed to seek L2 entry: {}", e)))?; + let mut entry_buf = [0u8; 8]; + self.file + .read_exact(&mut entry_buf) + .map_err(|e| WkrunError::Device(format!("qcow2: failed to read L2 entry: {}", e)))?; + let l2_entry = u64::from_be_bytes(entry_buf); + + let data_cluster_offset = l2_entry & L2_OFFSET_MASK; + if data_cluster_offset == 0 { + return Ok(None); + } + + Ok(Some(data_cluster_offset + offset_in_cluster)) + } + + /// Allocate a new cluster by appending to the file. + /// Updates refcount for the new cluster. + fn allocate_cluster(&mut self) -> Result { + let offset = self.allocate_raw_cluster()?; + self.set_refcount(offset, 1)?; + Ok(offset) + } + + /// Allocate a new cluster without updating refcounts. + /// Used internally to break recursion when allocating refcount blocks. + fn allocate_raw_cluster(&mut self) -> Result { + let offset = self.next_free_cluster; + let zeros = vec![0u8; self.cluster_size as usize]; + self.file + .seek(SeekFrom::Start(offset)) + .map_err(|e| WkrunError::Device(format!("qcow2: seek for alloc failed: {}", e)))?; + self.file + .write_all(&zeros) + .map_err(|e| WkrunError::Device(format!("qcow2: cluster alloc write failed: {}", e)))?; + self.next_free_cluster = offset + self.cluster_size; + Ok(offset) + } + + /// Set the refcount for a cluster at the given file offset. + /// + /// Navigates the two-level refcount table. If the refcount block + /// is missing, allocates one (using raw allocation to avoid recursion). 
+ fn set_refcount(&mut self, cluster_offset: u64, count: u16) -> Result<()> { + let cluster_index = cluster_offset / self.cluster_size; + let refcount_bits = 1u32 << self.header.refcount_order; + let entries_per_block = self.cluster_size * 8 / refcount_bits as u64; + + let refcount_table_index = (cluster_index / entries_per_block) as usize; + let block_index = cluster_index % entries_per_block; + + if refcount_table_index >= self.refcount_table.len() { + // Refcount table too small — skip for now (append-only images + // with limited allocations rarely hit this). + return Ok(()); + } + + let mut block_offset = self.refcount_table[refcount_table_index]; + if block_offset == 0 { + // Allocate a new refcount block (raw — no recursive refcount update). + block_offset = self.allocate_raw_cluster()?; + self.refcount_table[refcount_table_index] = block_offset; + // Write updated refcount table entry back to disk. + let rt_entry_offset = + self.header.refcount_table_offset + (refcount_table_index as u64) * 8; + self.file + .seek(SeekFrom::Start(rt_entry_offset)) + .map_err(|e| { + WkrunError::Device(format!("qcow2: seek refcount table entry: {}", e)) + })?; + self.file + .write_all(&block_offset.to_be_bytes()) + .map_err(|e| { + WkrunError::Device(format!("qcow2: write refcount table entry: {}", e)) + })?; + } + + // Write the 16-bit refcount entry. + let entry_offset = block_offset + block_index * (refcount_bits as u64 / 8); + self.file + .seek(SeekFrom::Start(entry_offset)) + .map_err(|e| WkrunError::Device(format!("qcow2: seek refcount entry: {}", e)))?; + self.file + .write_all(&count.to_be_bytes()) + .map_err(|e| WkrunError::Device(format!("qcow2: write refcount entry: {}", e)))?; + + Ok(()) + } + + /// Ensure an L2 table exists for the given L1 index. Allocates if needed. + /// Returns the file offset of the L2 table. 
+    ///
+    /// A newly created L1 entry is stored with the COPIED flag, since the
+    /// table it points to was just allocated with a refcount of one.
+    fn ensure_l2_table(&mut self, l1_index: usize) -> Result<u64> {
+        let l2_offset = self.l1_table[l1_index] & L2_OFFSET_MASK;
+        if l2_offset != 0 {
+            return Ok(l2_offset);
+        }
+
+        // Allocate a new L2 table cluster.
+        let new_l2_offset = self.allocate_cluster()?;
+        let l1_entry = new_l2_offset | Self::OFLAG_COPIED;
+
+        // Update in-memory L1 table.
+        self.l1_table[l1_index] = l1_entry;
+
+        // Write L1 entry back to disk.
+        let l1_entry_file_offset = self.header.l1_table_offset + (l1_index as u64) * 8;
+        self.file
+            .seek(SeekFrom::Start(l1_entry_file_offset))
+            .map_err(|e| WkrunError::Device(format!("qcow2: seek L1 entry: {}", e)))?;
+        self.file
+            .write_all(&l1_entry.to_be_bytes())
+            .map_err(|e| WkrunError::Device(format!("qcow2: write L1 entry: {}", e)))?;
+
+        Ok(new_l2_offset)
+    }
+
+    /// Bit 63 of an L1/L2 entry: the referenced cluster has a refcount of
+    /// exactly one and may be written in place.
+    const OFLAG_COPIED: u64 = 1 << 63;
+
+    /// Bit 62 of an L2 entry: the cluster is stored compressed.
+    const OFLAG_COMPRESSED: u64 = 1 << 62;
+
+    /// Ensure a data cluster exists for the given guest offset.
+    /// Allocates L2 table and/or data cluster if needed.
+    /// Returns the host file offset for the data.
+    ///
+    /// When a cluster is allocated over a backing file, the backing content
+    /// of the whole cluster is copied in first, so a write that covers only
+    /// part of the cluster does not turn the remainder into zeros.
+    ///
+    /// # Errors
+    ///
+    /// Fails when the offset lies beyond the L1 table, when the existing L2
+    /// entry describes a compressed cluster, or on I/O errors.
+    fn ensure_data_cluster(&mut self, guest_offset: u64) -> Result<u64> {
+        let cluster_index = guest_offset / self.cluster_size;
+        let l1_index = (cluster_index / self.l2_entries_per_table) as usize;
+        let l2_index = cluster_index % self.l2_entries_per_table;
+        let offset_in_cluster = guest_offset % self.cluster_size;
+
+        if l1_index >= self.l1_table.len() {
+            return Err(WkrunError::Device(format!(
+                "qcow2: guest offset {} exceeds virtual size",
+                guest_offset
+            )));
+        }
+
+        let l2_table_offset = self.ensure_l2_table(l1_index)?;
+
+        // Read the L2 entry.
+        let l2_entry_file_offset = l2_table_offset + l2_index * 8;
+        self.file
+            .seek(SeekFrom::Start(l2_entry_file_offset))
+            .map_err(|e| WkrunError::Device(format!("qcow2: seek L2 entry: {}", e)))?;
+        let mut entry_buf = [0u8; 8];
+        self.file
+            .read_exact(&mut entry_buf)
+            .map_err(|e| WkrunError::Device(format!("qcow2: read L2 entry: {}", e)))?;
+        let l2_entry = u64::from_be_bytes(entry_buf);
+
+        // A compressed entry is not a plain host offset; writing through it
+        // would overwrite unrelated compressed data.
+        if l2_entry & Self::OFLAG_COMPRESSED != 0 {
+            return Err(WkrunError::Device(format!(
+                "qcow2: cannot write to compressed cluster at guest offset {}",
+                guest_offset
+            )));
+        }
+
+        // NOTE(review): bit 0 of the entry (v3 "reads as zeros") is not
+        // interpreted here — confirm whether such images need support.
+        let data_offset = l2_entry & L2_OFFSET_MASK;
+        if data_offset != 0 {
+            return Ok(data_offset + offset_in_cluster);
+        }
+
+        // Allocate a new data cluster.
+        let new_data_offset = self.allocate_cluster()?;
+
+        // Copy-on-write: seed the new cluster with what the guest could read
+        // there until now, before the L2 entry makes the cluster visible.
+        if self.backing.is_some() {
+            let mut cluster = vec![0u8; self.cluster_size as usize];
+            self.read_backing(guest_offset - offset_in_cluster, &mut cluster)?;
+            self.file
+                .seek(SeekFrom::Start(new_data_offset))
+                .map_err(|e| WkrunError::Device(format!("qcow2: seek for COW copy: {}", e)))?;
+            self.file
+                .write_all(&cluster)
+                .map_err(|e| WkrunError::Device(format!("qcow2: COW copy failed: {}", e)))?;
+        }
+
+        // Write L2 entry back to disk.
+        let l2_entry_new = new_data_offset | Self::OFLAG_COPIED;
+        self.file
+            .seek(SeekFrom::Start(l2_entry_file_offset))
+            .map_err(|e| WkrunError::Device(format!("qcow2: seek L2 entry for write: {}", e)))?;
+        self.file
+            .write_all(&l2_entry_new.to_be_bytes())
+            .map_err(|e| WkrunError::Device(format!("qcow2: write L2 entry: {}", e)))?;
+
+        Ok(new_data_offset + offset_in_cluster)
+    }
+
+    /// Fill `buf` with what the backing file holds at `guest_offset`.
+    ///
+    /// Bytes beyond the end of the backing file, and the whole buffer when
+    /// there is no backing file, read as zeros.
+    fn read_backing(&mut self, guest_offset: u64, buf: &mut [u8]) -> Result<()> {
+        let Some(backing) = self.backing.as_mut() else {
+            buf.fill(0);
+            return Ok(());
+        };
+
+        let available = backing.capacity_bytes().saturating_sub(guest_offset);
+        let readable = usize::try_from(available).map_or(buf.len(), |n| n.min(buf.len()));
+        let (head, tail) = buf.split_at_mut(readable);
+        if !head.is_empty() {
+            backing.read_at(guest_offset, head)?;
+        }
+        tail.fill(0);
+        Ok(())
+    }
+}
+
+impl DiskBackend for Qcow2DiskBackend {
+    fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> Result<()> {
+        let mut pos = 0usize;
+        let mut guest_offset = offset;
+
+        // Walk the request one cluster-bounded chunk at a time.
+        while pos < buf.len() {
+            let offset_in_cluster = guest_offset % self.cluster_size;
+            let remaining_in_cluster = (self.cluster_size - offset_in_cluster) as usize;
+            let chunk_len = remaining_in_cluster.min(buf.len() - pos);
+            let chunk = &mut buf[pos..pos + chunk_len];
+
+            match self.resolve_offset(guest_offset)? {
+                Some(host_offset) => {
+                    self.file.seek(SeekFrom::Start(host_offset)).map_err(|e| {
+                        WkrunError::Device(format!("qcow2: read seek failed: {}", e))
+                    })?;
+                    self.file
+                        .read_exact(chunk)
+                        .map_err(|e| WkrunError::Device(format!("qcow2: read failed: {}", e)))?;
+                }
+                // Unallocated cluster — read from backing file or return zeros.
+                None => self.read_backing(guest_offset, chunk)?,
+            }
+
+            pos += chunk_len;
+            guest_offset += chunk_len as u64;
+        }
+
+        Ok(())
+    }
+
+    fn write_at(&mut self, offset: u64, buf: &[u8]) -> Result<()> {
+        if self.read_only {
+            return Err(WkrunError::Device(
+                "qcow2: write rejected on read-only disk".into(),
+            ));
+        }
+
+        let mut pos = 0usize;
+        let mut guest_offset = offset;
+
+        while pos < buf.len() {
+            let offset_in_cluster = guest_offset % self.cluster_size;
+            let remaining_in_cluster = (self.cluster_size - offset_in_cluster) as usize;
+            let chunk_len = remaining_in_cluster.min(buf.len() - pos);
+
+            let host_offset = self.ensure_data_cluster(guest_offset)?;
+
+            self.file
+                .seek(SeekFrom::Start(host_offset))
+                .map_err(|e| WkrunError::Device(format!("qcow2: write seek failed: {}", e)))?;
+            self.file
+                .write_all(&buf[pos..pos + chunk_len])
+                .map_err(|e| WkrunError::Device(format!("qcow2: write failed: {}", e)))?;
+
+            pos += chunk_len;
+            guest_offset += chunk_len as u64;
+        }
+
+        Ok(())
+    }
+
+    fn flush(&mut self) -> Result<()> {
+        self.file
+            .sync_all()
+            .map_err(|e| WkrunError::Device(format!("qcow2: flush failed: {}", e)))?;
+        Ok(())
+    }
+
+    fn capacity_bytes(&self) -> u64 {
+        self.header.size
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Factory
+// ---------------------------------------------------------------------------
+
+/// Open a disk backend based on the specified format.
+/// +/// - `DISK_FORMAT_RAW` (0): raw file passthrough +/// - `DISK_FORMAT_QCOW2` (1): qcow2 image with copy-on-write +pub fn open_disk_backend( + path: &Path, + format: u32, + read_only: bool, +) -> Result> { + match format { + DISK_FORMAT_RAW => { + let file = File::options() + .read(true) + .write(!read_only) + .open(path) + .map_err(|e| { + WkrunError::Device(format!("failed to open disk '{}': {}", path.display(), e)) + })?; + Ok(Box::new(RawDiskBackend::new(file)?)) + } + DISK_FORMAT_QCOW2 => { + let backend = Qcow2DiskBackend::open(path, read_only)?; + Ok(Box::new(backend)) + } + _ => Err(WkrunError::Device(format!( + "unsupported disk format: {}", + format + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write as IoWrite; + use tempfile::NamedTempFile; + + fn create_raw_file(size: usize) -> NamedTempFile { + let mut f = NamedTempFile::new().unwrap(); + f.write_all(&vec![0u8; size]).unwrap(); + f.flush().unwrap(); + f + } + + fn create_raw_file_with_pattern(sectors: u64) -> NamedTempFile { + let mut f = NamedTempFile::new().unwrap(); + for sector in 0..sectors { + let pattern = vec![(sector & 0xFF) as u8; 512]; + f.write_all(&pattern).unwrap(); + } + f.flush().unwrap(); + f + } + + // --- RawDiskBackend --- + + #[test] + fn test_raw_backend_capacity() { + let tmp = create_raw_file(4096); + let file = File::open(tmp.path()).unwrap(); + let backend = RawDiskBackend::new(file).unwrap(); + assert_eq!(backend.capacity_bytes(), 4096); + } + + #[test] + fn test_raw_backend_empty_file_error() { + let tmp = NamedTempFile::new().unwrap(); + let file = File::open(tmp.path()).unwrap(); + assert!(RawDiskBackend::new(file).is_err()); + } + + #[test] + fn test_raw_backend_read_at() { + let tmp = create_raw_file_with_pattern(4); + let file = File::options() + .read(true) + .write(true) + .open(tmp.path()) + .unwrap(); + let mut backend = RawDiskBackend::new(file).unwrap(); + + let mut buf = [0u8; 512]; + backend.read_at(512 * 2, &mut buf).unwrap(); 
+ assert!(buf.iter().all(|&b| b == 0x02)); + } + + #[test] + fn test_raw_backend_write_at() { + let tmp = create_raw_file(2048); + let file = File::options() + .read(true) + .write(true) + .open(tmp.path()) + .unwrap(); + let mut backend = RawDiskBackend::new(file).unwrap(); + + let data = vec![0xABu8; 512]; + backend.write_at(512, &data).unwrap(); + + let mut buf = [0u8; 512]; + backend.read_at(512, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xAB)); + } + + #[test] + fn test_raw_backend_flush() { + let tmp = create_raw_file(512); + let file = File::options() + .read(true) + .write(true) + .open(tmp.path()) + .unwrap(); + let mut backend = RawDiskBackend::new(file).unwrap(); + backend.flush().unwrap(); + } + + // --- open_disk_backend factory --- + + #[test] + fn test_factory_raw_format() { + let tmp = create_raw_file(1024); + let backend = open_disk_backend(tmp.path(), DISK_FORMAT_RAW, false).unwrap(); + assert_eq!(backend.capacity_bytes(), 1024); + } + + #[test] + fn test_factory_invalid_format() { + let tmp = create_raw_file(1024); + let result = open_disk_backend(tmp.path(), 99, false); + assert!(result.is_err()); + } + + // ----------------------------------------------------------------------- + // qcow2 test helpers + // ----------------------------------------------------------------------- + + /// Create a minimal qcow2 v2 image programmatically. + /// + /// Layout (cluster_size = 512 for small tests): + /// Cluster 0: header + /// Cluster 1: refcount table (1 entry pointing to cluster 2) + /// Cluster 2: refcount block (refcounts for clusters 0..N) + /// Cluster 3: L1 table + /// [Cluster 4+: optional pre-allocated L2 + data] + /// + /// `preallocated` is a list of (guest_byte_offset, data) pairs to + /// write into the image at construction time. 
+ fn create_test_qcow2( + virtual_size: u64, + cluster_bits: u32, + preallocated: &[(u64, &[u8])], + ) -> NamedTempFile { + let cluster_size = 1u64 << cluster_bits; + let l2_entries = cluster_size / 8; + + // Calculate L1 table size. + let l1_entries = virtual_size.div_ceil(cluster_size * l2_entries) as u32; + + // Fixed layout: + // Cluster 0: header + // Cluster 1: refcount table + // Cluster 2: refcount block + // Cluster 3: L1 table (may span multiple clusters but 1 for small tests) + let refcount_table_offset = cluster_size; + let refcount_block_offset = cluster_size * 2; + let l1_table_offset = cluster_size * 3; + let mut next_cluster = cluster_size * 4; // First free cluster. + + // Collect allocations needed for preallocated data. + struct PreallocInfo { + l2_idx: usize, + l2_cluster: u64, + data_cluster: u64, + data: Vec, + data_offset_in_cluster: u64, + } + + let mut l2_clusters: std::collections::HashMap = + std::collections::HashMap::new(); + let mut allocs = Vec::new(); + + for &(guest_offset, data) in preallocated { + let l1_idx = (guest_offset / cluster_size / l2_entries) as usize; + let l2_idx = ((guest_offset / cluster_size) % l2_entries) as usize; + let offset_in_cluster = guest_offset % cluster_size; + + let l2_cluster = *l2_clusters.entry(l1_idx).or_insert_with(|| { + let c = next_cluster; + next_cluster += cluster_size; + c + }); + + let data_cluster = next_cluster; + next_cluster += cluster_size; + + allocs.push(PreallocInfo { + l2_idx, + l2_cluster, + data_cluster, + data: data.to_vec(), + data_offset_in_cluster: offset_in_cluster, + }); + } + + let total_clusters = next_cluster / cluster_size; + let file_size = next_cluster; + + // Build the file. + let mut f = NamedTempFile::new().unwrap(); + let mut image = vec![0u8; file_size as usize]; + + // --- Header (cluster 0) --- + // Magic. + image[0..4].copy_from_slice(&QCOW2_MAGIC.to_be_bytes()); + // Version = 2. + image[4..8].copy_from_slice(&2u32.to_be_bytes()); + // Backing file offset = 0. 
+ image[8..16].copy_from_slice(&0u64.to_be_bytes()); + // Backing file size = 0. + image[16..20].copy_from_slice(&0u32.to_be_bytes()); + // Cluster bits. + image[20..24].copy_from_slice(&cluster_bits.to_be_bytes()); + // Virtual size. + image[24..32].copy_from_slice(&virtual_size.to_be_bytes()); + // Crypt method = 0. + image[32..36].copy_from_slice(&0u32.to_be_bytes()); + // L1 size. + image[36..40].copy_from_slice(&l1_entries.to_be_bytes()); + // L1 table offset. + image[40..48].copy_from_slice(&l1_table_offset.to_be_bytes()); + // Refcount table offset. + image[48..56].copy_from_slice(&refcount_table_offset.to_be_bytes()); + // Refcount table clusters = 1. + image[56..60].copy_from_slice(&1u32.to_be_bytes()); + // Nb snapshots = 0. + image[60..64].copy_from_slice(&0u32.to_be_bytes()); + + // --- Refcount table (cluster 1) --- + // Single entry pointing to refcount block at cluster 2. + let rt_off = refcount_table_offset as usize; + image[rt_off..rt_off + 8].copy_from_slice(&refcount_block_offset.to_be_bytes()); + + // --- Refcount block (cluster 2) --- + // Set refcount=1 for all allocated clusters (16-bit BE entries). + let rb_off = refcount_block_offset as usize; + for i in 0..total_clusters { + let entry_off = rb_off + (i as usize) * 2; + image[entry_off..entry_off + 2].copy_from_slice(&1u16.to_be_bytes()); + } + + // --- L1 table (cluster 3) --- + for (&l1_idx, &l2_cluster) in &l2_clusters { + let entry_off = l1_table_offset as usize + l1_idx * 8; + image[entry_off..entry_off + 8].copy_from_slice(&l2_cluster.to_be_bytes()); + } + + // --- L2 tables + data clusters --- + for alloc in &allocs { + // Write L2 entry. + let l2_entry_off = alloc.l2_cluster as usize + alloc.l2_idx * 8; + image[l2_entry_off..l2_entry_off + 8] + .copy_from_slice(&alloc.data_cluster.to_be_bytes()); + + // Write data. 
+ let data_off = alloc.data_cluster as usize + alloc.data_offset_in_cluster as usize; + let end = data_off + alloc.data.len(); + image[data_off..end].copy_from_slice(&alloc.data); + } + + f.write_all(&image).unwrap(); + f.flush().unwrap(); + f + } + + // ----------------------------------------------------------------------- + // qcow2 header parsing + // ----------------------------------------------------------------------- + + #[test] + fn test_qcow2_header_valid_v2() { + let tmp = create_test_qcow2(1024 * 1024, 16, &[]); + let backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + assert_eq!(backend.header.version, 2); + assert_eq!(backend.header.cluster_bits, 16); + assert_eq!(backend.capacity_bytes(), 1024 * 1024); + } + + #[test] + fn test_qcow2_header_bad_magic() { + let mut tmp = NamedTempFile::new().unwrap(); + let mut data = vec![0u8; 512]; + data[0..4].copy_from_slice(&0xDEADBEEFu32.to_be_bytes()); + tmp.write_all(&data).unwrap(); + tmp.flush().unwrap(); + + let err = Qcow2DiskBackend::open(tmp.path(), false).err().unwrap(); + assert!(err.to_string().contains("bad magic"), "error was: {}", err); + } + + #[test] + fn test_qcow2_header_bad_version() { + let mut tmp = NamedTempFile::new().unwrap(); + let mut data = vec![0u8; 512]; + data[0..4].copy_from_slice(&QCOW2_MAGIC.to_be_bytes()); + data[4..8].copy_from_slice(&1u32.to_be_bytes()); // Version 1. + tmp.write_all(&data).unwrap(); + tmp.flush().unwrap(); + + let err = Qcow2DiskBackend::open(tmp.path(), false).err().unwrap(); + assert!(err.to_string().contains("version"), "error was: {}", err); + } + + #[test] + fn test_qcow2_header_backing_file_parsed() { + // Verify that header parsing accepts backing_file_offset != 0. + let mut buf = [0u8; 104]; + buf[0..4].copy_from_slice(&QCOW2_MAGIC.to_be_bytes()); + buf[4..8].copy_from_slice(&2u32.to_be_bytes()); + buf[8..16].copy_from_slice(&100u64.to_be_bytes()); // Backing file offset. 
+ buf[16..20].copy_from_slice(&10u32.to_be_bytes()); // Backing file size. + buf[20..24].copy_from_slice(&16u32.to_be_bytes()); // cluster_bits. + buf[24..32].copy_from_slice(&(1024u64 * 1024).to_be_bytes()); // size. + buf[36..40].copy_from_slice(&1u32.to_be_bytes()); // l1_size. + buf[40..48].copy_from_slice(&(65536u64).to_be_bytes()); // l1_table_offset. + buf[48..56].copy_from_slice(&(65536u64).to_be_bytes()); // refcount_table_offset. + buf[56..60].copy_from_slice(&1u32.to_be_bytes()); // refcount_table_clusters. + + let header = Qcow2Header::parse(&buf).unwrap(); + assert_eq!(header.backing_file_offset, 100); + assert_eq!(header.backing_file_size, 10); + } + + #[test] + fn test_qcow2_header_encryption_rejected() { + let mut tmp = NamedTempFile::new().unwrap(); + let mut data = vec![0u8; 512]; + data[0..4].copy_from_slice(&QCOW2_MAGIC.to_be_bytes()); + data[4..8].copy_from_slice(&2u32.to_be_bytes()); + data[8..16].copy_from_slice(&0u64.to_be_bytes()); // No backing. + data[20..24].copy_from_slice(&16u32.to_be_bytes()); // cluster_bits. + data[24..32].copy_from_slice(&(1024u64 * 1024).to_be_bytes()); + data[32..36].copy_from_slice(&1u32.to_be_bytes()); // Encrypted! + tmp.write_all(&data).unwrap(); + tmp.flush().unwrap(); + + let err = Qcow2DiskBackend::open(tmp.path(), false).err().unwrap(); + assert!(err.to_string().contains("encryption"), "error was: {}", err); + } + + // ----------------------------------------------------------------------- + // qcow2 reads + // ----------------------------------------------------------------------- + + #[test] + fn test_qcow2_read_unallocated_returns_zeros() { + // 1MB image with no preallocated data, cluster_bits=9 (512B clusters). 
+ let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + let mut buf = [0xFFu8; 512]; + backend.read_at(0, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0)); + } + + #[test] + fn test_qcow2_read_allocated_cluster() { + let pattern = vec![0xABu8; 128]; + let tmp = create_test_qcow2(1024 * 1024, 9, &[(512, &pattern)]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + let mut buf = [0u8; 128]; + backend.read_at(512, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xAB)); + } + + #[test] + fn test_qcow2_read_cross_cluster_boundary() { + // Two adjacent clusters with different data. + let data0 = vec![0x11u8; 512]; + let data1 = vec![0x22u8; 512]; + let tmp = create_test_qcow2(1024 * 1024, 9, &[(0, &data0), (512, &data1)]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + // Read 256 bytes spanning the boundary (last 128 of cluster 0 + first 128 of cluster 1). 
+ let mut buf = [0u8; 256]; + backend.read_at(384, &mut buf).unwrap(); + assert!(buf[..128].iter().all(|&b| b == 0x11)); + assert!(buf[128..].iter().all(|&b| b == 0x22)); + } + + #[test] + fn test_qcow2_capacity() { + let tmp = create_test_qcow2(2 * 1024 * 1024, 16, &[]); + let backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + assert_eq!(backend.capacity_bytes(), 2 * 1024 * 1024); + } + + // ----------------------------------------------------------------------- + // qcow2 writes + // ----------------------------------------------------------------------- + + #[test] + fn test_qcow2_write_allocates_cluster() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + let data = vec![0xCDu8; 256]; + backend.write_at(0, &data).unwrap(); + + let mut buf = [0u8; 256]; + backend.read_at(0, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xCD)); + } + + #[test] + fn test_qcow2_write_read_roundtrip() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + // Write different patterns at different offsets. + backend.write_at(0, &[0x11; 512]).unwrap(); + backend.write_at(512, &[0x22; 512]).unwrap(); + backend.write_at(1024, &[0x33; 512]).unwrap(); + + let mut buf0 = [0u8; 512]; + let mut buf1 = [0u8; 512]; + let mut buf2 = [0u8; 512]; + backend.read_at(0, &mut buf0).unwrap(); + backend.read_at(512, &mut buf1).unwrap(); + backend.read_at(1024, &mut buf2).unwrap(); + + assert!(buf0.iter().all(|&b| b == 0x11)); + assert!(buf1.iter().all(|&b| b == 0x22)); + assert!(buf2.iter().all(|&b| b == 0x33)); + } + + #[test] + fn test_qcow2_write_partial_cluster() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + // Write 100 bytes in the middle of cluster 0. 
+ backend.write_at(200, &[0xBB; 100]).unwrap(); + + // Verify: first 200 bytes = zeros, next 100 = 0xBB, rest = zeros. + let mut buf = [0u8; 512]; + backend.read_at(0, &mut buf).unwrap(); + assert!(buf[..200].iter().all(|&b| b == 0x00)); + assert!(buf[200..300].iter().all(|&b| b == 0xBB)); + assert!(buf[300..].iter().all(|&b| b == 0x00)); + } + + #[test] + fn test_qcow2_write_cross_cluster_boundary() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + // Write 256 bytes spanning cluster boundary (cluster_size=512). + let data = vec![0xEE; 256]; + backend.write_at(384, &data).unwrap(); + + let mut buf = [0u8; 256]; + backend.read_at(384, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xEE)); + + // Verify untouched parts. + let mut before = [0u8; 384]; + backend.read_at(0, &mut before).unwrap(); + assert!(before.iter().all(|&b| b == 0x00)); + + let mut after = [0u8; 128]; + backend.read_at(640, &mut after).unwrap(); + assert!(after.iter().all(|&b| b == 0x00)); + } + + #[test] + fn test_qcow2_write_same_cluster_no_realloc() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + backend.write_at(0, &[0x11; 256]).unwrap(); + let free_before = backend.next_free_cluster; + + // Write again to the same cluster — should not allocate new clusters. + backend.write_at(256, &[0x22; 256]).unwrap(); + assert_eq!(backend.next_free_cluster, free_before); + + // Verify both writes persisted. + let mut buf = [0u8; 512]; + backend.read_at(0, &mut buf).unwrap(); + assert!(buf[..256].iter().all(|&b| b == 0x11)); + assert!(buf[256..].iter().all(|&b| b == 0x22)); + } + + #[test] + fn test_qcow2_l2_table_allocation() { + // Use cluster_bits=9 (512B), virtual_size=1MB. + // L2 entries per table = 512/8 = 64. + // So each L1 entry covers 64*512 = 32768 bytes. + // Writing at offset 32768 requires L1 index=1 (new L2 table). 
+ let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + + let data = vec![0xAA; 512]; + backend.write_at(32768, &data).unwrap(); + + let mut buf = [0u8; 512]; + backend.read_at(32768, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xAA)); + } + + #[test] + fn test_qcow2_read_only_rejects_writes() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), true).unwrap(); + + let result = backend.write_at(0, &[0x11; 512]); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("read-only"), "error was: {}", err); + } + + #[test] + fn test_qcow2_flush() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = Qcow2DiskBackend::open(tmp.path(), false).unwrap(); + backend.write_at(0, &[0x42; 512]).unwrap(); + backend.flush().unwrap(); + } + + // ----------------------------------------------------------------------- + // Factory: qcow2 dispatch + // ----------------------------------------------------------------------- + + #[test] + fn test_factory_qcow2_format() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let mut backend = open_disk_backend(tmp.path(), DISK_FORMAT_QCOW2, false).unwrap(); + assert_eq!(backend.capacity_bytes(), 1024 * 1024); + + // Write + read through the factory-created backend. + backend.write_at(0, &[0x99; 512]).unwrap(); + let mut buf = [0u8; 512]; + backend.read_at(0, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0x99)); + } + + // ----------------------------------------------------------------------- + // Backing file support + // ----------------------------------------------------------------------- + + /// Create a minimal qcow2 v2 image with a backing file reference. 
+ /// + /// Layout (cluster_size = 512): + /// Cluster 0: header + backing file path + /// Cluster 1: refcount table + /// Cluster 2: refcount block + /// Cluster 3: L1 table (all zeros — everything reads from backing) + fn create_test_qcow2_with_backing( + virtual_size: u64, + cluster_bits: u32, + backing_path: &Path, + ) -> NamedTempFile { + let cluster_size = 1u64 << cluster_bits; + let l2_entries = cluster_size / 8; + let l1_entries = virtual_size.div_ceil(cluster_size * l2_entries) as u32; + + let backing_path_bytes = backing_path.to_string_lossy().as_bytes().to_vec(); + let backing_path_len = backing_path_bytes.len() as u32; + // Store backing path right after the 104-byte header. + let backing_file_offset: u64 = 104; + + let refcount_table_offset = cluster_size; + let refcount_block_offset = cluster_size * 2; + let l1_table_offset = cluster_size * 3; + let total_clusters = 4u64; + let file_size = cluster_size * total_clusters; + + let mut f = NamedTempFile::new().unwrap(); + let mut image = vec![0u8; file_size as usize]; + + // --- Header (cluster 0) --- + image[0..4].copy_from_slice(&QCOW2_MAGIC.to_be_bytes()); + image[4..8].copy_from_slice(&2u32.to_be_bytes()); // version + image[8..16].copy_from_slice(&backing_file_offset.to_be_bytes()); + image[16..20].copy_from_slice(&backing_path_len.to_be_bytes()); + image[20..24].copy_from_slice(&cluster_bits.to_be_bytes()); + image[24..32].copy_from_slice(&virtual_size.to_be_bytes()); + image[32..36].copy_from_slice(&0u32.to_be_bytes()); // crypt_method + image[36..40].copy_from_slice(&l1_entries.to_be_bytes()); + image[40..48].copy_from_slice(&l1_table_offset.to_be_bytes()); + image[48..56].copy_from_slice(&refcount_table_offset.to_be_bytes()); + image[56..60].copy_from_slice(&1u32.to_be_bytes()); // refcount_table_clusters + image[60..64].copy_from_slice(&0u32.to_be_bytes()); // nb_snapshots + + // Backing file path (after header). 
+ let start = backing_file_offset as usize; + image[start..start + backing_path_bytes.len()].copy_from_slice(&backing_path_bytes); + + // --- Refcount table (cluster 1) --- + let rt_off = refcount_table_offset as usize; + image[rt_off..rt_off + 8].copy_from_slice(&refcount_block_offset.to_be_bytes()); + + // --- Refcount block (cluster 2) --- + let rb_off = refcount_block_offset as usize; + for i in 0..total_clusters { + let entry_off = rb_off + (i as usize) * 2; + image[entry_off..entry_off + 2].copy_from_slice(&1u16.to_be_bytes()); + } + + // L1 table (cluster 3) — all zeros (everything unallocated → backing). + + f.write_all(&image).unwrap(); + f.flush().unwrap(); + f + } + + #[test] + fn test_qcow2_backing_file_read() { + // Create a raw base disk with a known pattern. + let base = create_raw_file_with_pattern(8); // 8 sectors = 4096 bytes + let base_path = base.path().to_path_buf(); + + // Create a QCOW2 child that references the base as backing. + let child = create_test_qcow2_with_backing(4096, 9, &base_path); + + let mut backend = Qcow2DiskBackend::open(child.path(), false).unwrap(); + + // Read sector 0 — should come from backing (pattern byte = 0x00). + let mut buf = [0u8; 512]; + backend.read_at(0, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0x00)); + + // Read sector 3 — should come from backing (pattern byte = 0x03). + backend.read_at(512 * 3, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0x03)); + + // Read sector 7 — should come from backing (pattern byte = 0x07). + backend.read_at(512 * 7, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0x07)); + } + + #[test] + fn test_qcow2_backing_file_cow_write() { + // Create a raw base disk with pattern. + let base = create_raw_file_with_pattern(8); + let base_path = base.path().to_path_buf(); + + let child = create_test_qcow2_with_backing(4096, 9, &base_path); + let mut backend = Qcow2DiskBackend::open(child.path(), false).unwrap(); + + // Write to sector 2 in the child. 
+ backend.write_at(512 * 2, &[0xFF; 512]).unwrap(); + + // Read sector 2 — should reflect the child write (0xFF). + let mut buf = [0u8; 512]; + backend.read_at(512 * 2, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xFF)); + + // Read sector 3 — should still come from backing (0x03). + backend.read_at(512 * 3, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0x03)); + + // Read sector 0 — should still come from backing (0x00). + backend.read_at(0, &mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0x00)); + } + + #[test] + fn test_qcow2_backing_file_missing_errors() { + let missing_path = Path::new("/nonexistent/backing/file.raw"); + let child = create_test_qcow2_with_backing(4096, 9, missing_path); + + let result = Qcow2DiskBackend::open(child.path(), false); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("nonexistent") || err.contains("No such file"), + "error was: {}", + err + ); + } + + #[test] + fn test_detect_disk_format_raw() { + let tmp = create_raw_file(1024); + let fmt = detect_disk_format(tmp.path()).unwrap(); + assert_eq!(fmt, DISK_FORMAT_RAW); + } + + #[test] + fn test_detect_disk_format_qcow2() { + let tmp = create_test_qcow2(1024 * 1024, 9, &[]); + let fmt = detect_disk_format(tmp.path()).unwrap(); + assert_eq!(fmt, DISK_FORMAT_QCOW2); + } +} diff --git a/src/vmm/src/windows/devices/virtio/mmio.rs b/src/vmm/src/windows/devices/virtio/mmio.rs new file mode 100644 index 000000000..c1cd8a4c3 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/mmio.rs @@ -0,0 +1,734 @@ +//! Virtio-MMIO transport (virtio spec v1.2 Section 4.2). +//! +//! Register file at a memory-mapped I/O address. The guest accesses +//! device registers via MMIO reads/writes which trigger VM exits. + +use super::queue::{GuestMemoryAccessor, Virtqueue}; + +/// MMIO base address for the first virtio device. +/// Placed above guest RAM (256MB) and below the 4GB identity map. 
+pub const VIRTIO_MMIO_BASE: u64 = 0xD000_0000; + +/// Size of the MMIO register region (512 bytes covers all registers + config). +pub const VIRTIO_MMIO_SIZE: u64 = 0x200; + +// Virtio-MMIO register offsets (virtio spec 4.2.2). +const MAGIC_VALUE: u64 = 0x000; +const VERSION: u64 = 0x004; +const DEVICE_ID: u64 = 0x008; +const VENDOR_ID: u64 = 0x00C; +const DEVICE_FEATURES: u64 = 0x010; +const DEVICE_FEATURES_SEL: u64 = 0x014; +const DRIVER_FEATURES: u64 = 0x020; +const DRIVER_FEATURES_SEL: u64 = 0x024; +const QUEUE_SEL: u64 = 0x030; +const QUEUE_NUM_MAX: u64 = 0x034; +const QUEUE_NUM: u64 = 0x038; +const QUEUE_READY: u64 = 0x044; +const QUEUE_NOTIFY: u64 = 0x050; +const INTERRUPT_STATUS: u64 = 0x060; +const INTERRUPT_ACK: u64 = 0x064; +const STATUS: u64 = 0x070; +const QUEUE_DESC_LOW: u64 = 0x080; +const QUEUE_DESC_HIGH: u64 = 0x084; +const QUEUE_AVAIL_LOW: u64 = 0x090; +const QUEUE_AVAIL_HIGH: u64 = 0x094; +const QUEUE_USED_LOW: u64 = 0x0A0; +const QUEUE_USED_HIGH: u64 = 0x0A4; +const CONFIG_GENERATION: u64 = 0x0FC; +const CONFIG_SPACE: u64 = 0x100; + +// Virtio device status bits (virtio spec 2.1) — used in tests. +#[cfg(test)] +const STATUS_ACK: u32 = 1; +#[cfg(test)] +const STATUS_DRIVER: u32 = 2; +#[cfg(test)] +const STATUS_FEATURES_OK: u32 = 8; +#[cfg(test)] +const STATUS_DRIVER_OK: u32 = 4; + +/// Magic value identifying a virtio-MMIO device ("virt" in little-endian). +const VIRTIO_MMIO_MAGIC: u32 = 0x7472_6976; + +/// Virtio-MMIO version (2 = virtio 1.0+). +const VIRTIO_MMIO_VERSION: u32 = 2; + +/// Vendor ID — "QEMU" in little-endian (standard for virtio devices). +/// The Linux kernel's virtio-mmio driver rejects devices with vendor_id == 0. +const VIRTIO_VENDOR_ID: u32 = 0x554D_4551; + +// Interrupt status bits. +const INTERRUPT_USED_RING: u32 = 1; + +/// Backend trait that specific virtio devices implement. +pub trait VirtioDeviceBackend { + /// Virtio device ID (e.g., 2 for block). 
+ fn device_id(&self) -> u32; + + /// Return device feature bits for the given feature page (0 or 1). + fn device_features(&self, page: u32) -> u32; + + /// Read a 32-bit value from the device config space at the given offset. + fn read_config(&self, offset: u64) -> u32; + + /// Write a 32-bit value to the device config space at the given offset. + /// + /// Default: no-op (most devices have read-only config space). + /// Devices with writable config fields (e.g., virtio-balloon `actual`) + /// should override this. + fn write_config(&mut self, _offset: u64, _value: u32) {} + + /// Handle a queue notification (guest made buffers available). + /// + /// Returns `true` if the device processed buffers and an interrupt + /// should be raised. + fn queue_notify( + &mut self, + queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool; + + /// Number of virtqueues this device uses. + fn num_queues(&self) -> usize; + + /// Maximum queue size for the given queue index. + fn queue_max_size(&self, queue_idx: u32) -> u16; + + /// Poll for host-initiated events (e.g., incoming network/vsock data). + /// + /// Called from the vCPU run loop. Returns `true` if an interrupt + /// should be raised (device placed data in the used ring). + /// Default: no host-initiated events (suitable for block devices). + fn poll(&mut self, _queues: &mut [Virtqueue], _mem: &dyn GuestMemoryAccessor) -> bool { + false + } + + /// Drain async I/O completions from a worker thread. + /// + /// Called from the vCPU run loop for devices with async backends + /// (e.g., virtio-blk with a worker thread). Returns `true` if + /// completions were processed and an interrupt should be raised. + /// Default: no async completions. + fn drain_completions( + &mut self, + _queues: &mut [Virtqueue], + _mem: &dyn GuestMemoryAccessor, + ) -> bool { + false + } +} + +/// Virtio-MMIO device wrapping a backend. 
+pub struct VirtioMmioDevice { + backend: D, + queues: Vec, + /// Currently selected queue index (via QUEUE_SEL). + queue_sel: u32, + /// Device status register. + status: u32, + /// Device feature selection page. + device_features_sel: u32, + /// Driver feature selection page. + driver_features_sel: u32, + /// Driver-acknowledged feature bits (page 0 and page 1). + driver_features: [u32; 2], + /// Interrupt status register. + interrupt_status: u32, +} + +impl VirtioMmioDevice { + /// Create a new MMIO device wrapping the given backend. + pub fn new(backend: D) -> Self { + let num_queues = backend.num_queues(); + let mut queues = Vec::with_capacity(num_queues); + for i in 0..num_queues { + queues.push(Virtqueue::new(backend.queue_max_size(i as u32))); + } + + VirtioMmioDevice { + backend, + queues, + queue_sel: 0, + status: 0, + device_features_sel: 0, + driver_features_sel: 0, + driver_features: [0; 2], + interrupt_status: 0, + } + } + + /// Get a reference to the backend. + pub fn backend(&self) -> &D { + &self.backend + } + + /// Get a mutable reference to the backend. + pub fn backend_mut(&mut self) -> &mut D { + &mut self.backend + } + + /// Get the current interrupt status (non-zero = interrupt pending). + pub fn interrupt_status(&self) -> u32 { + self.interrupt_status + } + + /// Handle an MMIO read at the given offset from the device base. + pub fn read(&self, offset: u64, size: u8) -> u32 { + // All MMIO register reads are 32-bit in virtio-MMIO v2. 
+ if size != 4 && offset < CONFIG_SPACE { + return 0; + } + + match offset { + MAGIC_VALUE => VIRTIO_MMIO_MAGIC, + VERSION => VIRTIO_MMIO_VERSION, + DEVICE_ID => self.backend.device_id(), + VENDOR_ID => VIRTIO_VENDOR_ID, + DEVICE_FEATURES => self.backend.device_features(self.device_features_sel), + QUEUE_NUM_MAX => { + if let Some(q) = self.current_queue() { + q.max_size() as u32 + } else { + 0 + } + } + QUEUE_READY => { + if let Some(q) = self.current_queue() { + q.is_ready() as u32 + } else { + 0 + } + } + INTERRUPT_STATUS => self.interrupt_status, + STATUS => self.status, + CONFIG_GENERATION => 0, // Config doesn't change dynamically. + off if off >= CONFIG_SPACE => { + let config_offset = off - CONFIG_SPACE; + let aligned_offset = config_offset & !3; + let word = self.backend.read_config(aligned_offset); + if size == 4 { + word + } else { + // Byte/word access: extract the correct portion. + // Config space is little-endian; byte N within the u32 is + // at bits (N*8)..(N*8+8). + let byte_index = (config_offset & 3) as u32; + (word >> (byte_index * 8)) & ((1u32 << (size as u32 * 8)) - 1) + } + } + _ => 0, + } + } + + /// Handle an MMIO write at the given offset from the device base. + /// + /// `mem` is needed for queue_notify to process descriptor chains. + /// Returns `true` if an interrupt should be raised. + pub fn write( + &mut self, + offset: u64, + value: u32, + size: u8, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + // All MMIO register writes are 32-bit in virtio-MMIO v2. 
+ if size != 4 { + return false; + } + + match offset { + DEVICE_FEATURES_SEL => { + self.device_features_sel = value; + } + DRIVER_FEATURES => { + let sel = self.driver_features_sel as usize; + if sel < self.driver_features.len() { + self.driver_features[sel] = value; + } + } + DRIVER_FEATURES_SEL => { + self.driver_features_sel = value; + } + QUEUE_SEL => { + self.queue_sel = value; + } + QUEUE_NUM => { + if let Some(q) = self.current_queue_mut() { + q.set_size(value as u16); + } + } + QUEUE_READY => { + if let Some(q) = self.current_queue_mut() { + q.set_ready(value == 1); + } + } + QUEUE_NOTIFY => { + return self.handle_queue_notify(value, mem); + } + INTERRUPT_ACK => { + self.interrupt_status &= !value; + } + STATUS => { + self.handle_status_write(value); + } + QUEUE_DESC_LOW => { + if let Some(q) = self.current_queue_mut() { + let high = 0u64; // Will be combined in set_desc_table. + q.set_desc_table(value as u64 | high); + } + } + QUEUE_DESC_HIGH => { + // High bits for descriptor table address (typically 0 for < 4GB). + } + QUEUE_AVAIL_LOW => { + if let Some(q) = self.current_queue_mut() { + q.set_avail_ring(value as u64); + } + } + QUEUE_AVAIL_HIGH => { + // High bits for avail ring address (typically 0). + } + QUEUE_USED_LOW => { + if let Some(q) = self.current_queue_mut() { + q.set_used_ring(value as u64); + } + } + QUEUE_USED_HIGH => { + // High bits for used ring address (typically 0). + } + off if off >= CONFIG_SPACE => { + let config_offset = off - CONFIG_SPACE; + self.backend.write_config(config_offset, value); + } + _ => {} + } + false + } + + /// Poll the backend for host-initiated events. + /// + /// Returns `true` if an interrupt should be raised. + pub fn poll(&mut self, mem: &dyn GuestMemoryAccessor) -> bool { + let raised = self.backend.poll(&mut self.queues, mem); + if raised { + self.interrupt_status |= INTERRUPT_USED_RING; + } + raised + } + + /// Drain async I/O completions from the backend's worker thread. 
+ /// + /// Returns `true` if completions were processed (interrupt should be raised). + pub fn poll_backend(&mut self, mem: &dyn GuestMemoryAccessor) -> bool { + let raised = self.backend.drain_completions(&mut self.queues, mem); + if raised { + self.interrupt_status |= INTERRUPT_USED_RING; + } + raised + } + + fn current_queue(&self) -> Option<&Virtqueue> { + self.queues.get(self.queue_sel as usize) + } + + fn current_queue_mut(&mut self) -> Option<&mut Virtqueue> { + self.queues.get_mut(self.queue_sel as usize) + } + + fn handle_queue_notify(&mut self, queue_idx: u32, mem: &dyn GuestMemoryAccessor) -> bool { + let idx = queue_idx as usize; + if idx >= self.queues.len() { + return false; + } + + // Split borrow: take queue out, call backend, put it back. + let raised = self + .backend + .queue_notify(queue_idx, &mut self.queues[idx], mem); + + if raised { + self.interrupt_status |= INTERRUPT_USED_RING; + } + + raised + } + + fn handle_status_write(&mut self, value: u32) { + if value == 0 { + // Device reset. + self.status = 0; + self.queue_sel = 0; + self.interrupt_status = 0; + self.device_features_sel = 0; + self.driver_features_sel = 0; + self.driver_features = [0; 2]; + for q in &mut self.queues { + q.reset(); + } + return; + } + // Status can only be set by ORing new bits in. + self.status = value; + } +} + +#[cfg(test)] +mod tests { + use super::super::super::error::Result; + use super::queue::GuestMemoryAccessor; + use super::*; + use std::cell::RefCell; + + /// Null backend for testing the MMIO transport layer. + struct NullBackend; + + impl VirtioDeviceBackend for NullBackend { + fn device_id(&self) -> u32 { + 0 // Invalid/null device. 
+ } + fn device_features(&self, _page: u32) -> u32 { + 0 + } + fn read_config(&self, _offset: u64) -> u32 { + 0 + } + fn queue_notify( + &mut self, + _queue_idx: u32, + _queue: &mut Virtqueue, + _mem: &dyn GuestMemoryAccessor, + ) -> bool { + false + } + fn num_queues(&self) -> usize { + 1 + } + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + 256 + } + } + + /// Test backend that tracks notifications. + struct TestBackend { + notify_count: RefCell, + } + + impl TestBackend { + fn new() -> Self { + TestBackend { + notify_count: RefCell::new(0), + } + } + } + + impl VirtioDeviceBackend for TestBackend { + fn device_id(&self) -> u32 { + 2 // Block device. + } + fn device_features(&self, page: u32) -> u32 { + if page == 0 { + 0x1234 + } else { + 0 + } + } + fn read_config(&self, offset: u64) -> u32 { + if offset == 0 { + 1024 + } else { + 0 + } // Capacity low. + } + fn queue_notify( + &mut self, + _queue_idx: u32, + _queue: &mut Virtqueue, + _mem: &dyn GuestMemoryAccessor, + ) -> bool { + *self.notify_count.borrow_mut() += 1; + true // Raise interrupt. 
+ } + fn num_queues(&self) -> usize { + 1 + } + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + 128 + } + } + + struct MockMem(RefCell>); + impl MockMem { + fn new(size: usize) -> Self { + MockMem(RefCell::new(vec![0u8; size])) + } + } + impl GuestMemoryAccessor for MockMem { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + let a = addr as usize; + let data = self.0.borrow(); + buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + let a = addr as usize; + let mut mem = self.0.borrow_mut(); + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + // --- Magic and identification --- + + #[test] + fn test_magic_value() { + let dev = VirtioMmioDevice::new(NullBackend); + assert_eq!(dev.read(MAGIC_VALUE, 4), VIRTIO_MMIO_MAGIC); + } + + #[test] + fn test_version() { + let dev = VirtioMmioDevice::new(NullBackend); + assert_eq!(dev.read(VERSION, 4), 2); + } + + #[test] + fn test_device_id() { + let dev = VirtioMmioDevice::new(TestBackend::new()); + assert_eq!(dev.read(DEVICE_ID, 4), 2); // Block device. 
+ } + + #[test] + fn test_vendor_id() { + let dev = VirtioMmioDevice::new(NullBackend); + assert_eq!(dev.read(VENDOR_ID, 4), 0); + } + + // --- Device features --- + + #[test] + fn test_device_features_page0() { + let mut dev = VirtioMmioDevice::new(TestBackend::new()); + let mem = MockMem::new(64); + dev.write(DEVICE_FEATURES_SEL, 0, 4, &mem); + assert_eq!(dev.read(DEVICE_FEATURES, 4), 0x1234); + } + + #[test] + fn test_device_features_page1() { + let mut dev = VirtioMmioDevice::new(TestBackend::new()); + let mem = MockMem::new(64); + dev.write(DEVICE_FEATURES_SEL, 1, 4, &mem); + assert_eq!(dev.read(DEVICE_FEATURES, 4), 0); + } + + // --- Queue configuration --- + + #[test] + fn test_queue_max_size() { + let dev = VirtioMmioDevice::new(TestBackend::new()); + assert_eq!(dev.read(QUEUE_NUM_MAX, 4), 128); + } + + #[test] + fn test_queue_ready() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem = MockMem::new(64); + assert_eq!(dev.read(QUEUE_READY, 4), 0); + dev.write(QUEUE_READY, 1, 4, &mem); + assert_eq!(dev.read(QUEUE_READY, 4), 1); + } + + // --- Status state machine --- + + #[test] + fn test_status_ack() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem = MockMem::new(64); + assert_eq!(dev.read(STATUS, 4), 0); + dev.write(STATUS, STATUS_ACK, 4, &mem); + assert_eq!(dev.read(STATUS, 4), STATUS_ACK); + } + + #[test] + fn test_status_progression() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem = MockMem::new(64); + dev.write(STATUS, STATUS_ACK, 4, &mem); + dev.write(STATUS, STATUS_ACK | STATUS_DRIVER, 4, &mem); + dev.write( + STATUS, + STATUS_ACK | STATUS_DRIVER | STATUS_FEATURES_OK, + 4, + &mem, + ); + dev.write( + STATUS, + STATUS_ACK | STATUS_DRIVER | STATUS_FEATURES_OK | STATUS_DRIVER_OK, + 4, + &mem, + ); + assert_eq!( + dev.read(STATUS, 4), + STATUS_ACK | STATUS_DRIVER | STATUS_FEATURES_OK | STATUS_DRIVER_OK + ); + } + + #[test] + fn test_status_reset() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem 
= MockMem::new(64); + dev.write(STATUS, STATUS_ACK | STATUS_DRIVER, 4, &mem); + assert_ne!(dev.read(STATUS, 4), 0); + dev.write(STATUS, 0, 4, &mem); // Reset. + assert_eq!(dev.read(STATUS, 4), 0); + } + + // --- Interrupt handling --- + + #[test] + fn test_interrupt_on_notify() { + let mut dev = VirtioMmioDevice::new(TestBackend::new()); + let mem = MockMem::new(64); + + assert_eq!(dev.read(INTERRUPT_STATUS, 4), 0); + + // Notify queue 0. + let raised = dev.write(QUEUE_NOTIFY, 0, 4, &mem); + assert!(raised); + assert_eq!(dev.read(INTERRUPT_STATUS, 4), INTERRUPT_USED_RING); + } + + #[test] + fn test_interrupt_ack() { + let mut dev = VirtioMmioDevice::new(TestBackend::new()); + let mem = MockMem::new(64); + + dev.write(QUEUE_NOTIFY, 0, 4, &mem); + assert_eq!(dev.read(INTERRUPT_STATUS, 4), INTERRUPT_USED_RING); + + // Acknowledge the interrupt. + dev.write(INTERRUPT_ACK, INTERRUPT_USED_RING, 4, &mem); + assert_eq!(dev.read(INTERRUPT_STATUS, 4), 0); + } + + // --- Config space --- + + #[test] + fn test_config_space_read() { + let dev = VirtioMmioDevice::new(TestBackend::new()); + // Offset 0x100 = config space offset 0 → capacity low = 1024. + assert_eq!(dev.read(CONFIG_SPACE, 4), 1024); + } + + #[test] + fn test_config_space_byte_reads() { + // Simulates how the Linux virtio-mmio driver reads the MAC: one byte at a time. + // TestBackend returns 1024 (= 0x00000400) at config offset 0. + let dev = VirtioMmioDevice::new(TestBackend::new()); + // 1024 as LE bytes: [0x00, 0x04, 0x00, 0x00] + assert_eq!(dev.read(CONFIG_SPACE + 0, 1), 0x00); // byte 0 + assert_eq!(dev.read(CONFIG_SPACE + 1, 1), 0x04); // byte 1 + assert_eq!(dev.read(CONFIG_SPACE + 2, 1), 0x00); // byte 2 + assert_eq!(dev.read(CONFIG_SPACE + 3, 1), 0x00); // byte 3 + } + + #[test] + fn test_config_space_word_reads() { + let dev = VirtioMmioDevice::new(TestBackend::new()); + // 1024 = 0x0400. Two-byte read at offset 0 should give 0x0400. 
+ assert_eq!(dev.read(CONFIG_SPACE + 0, 2), 0x0400); + // Two-byte read at offset 2 should give 0x0000. + assert_eq!(dev.read(CONFIG_SPACE + 2, 2), 0x0000); + } + + // --- Non-32-bit access --- + + #[test] + fn test_non_32bit_read_returns_zero() { + let dev = VirtioMmioDevice::new(NullBackend); + // Reading magic with size != 4 should return 0. + assert_eq!(dev.read(MAGIC_VALUE, 1), 0); + assert_eq!(dev.read(MAGIC_VALUE, 2), 0); + } + + #[test] + fn test_non_32bit_write_ignored() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem = MockMem::new(64); + dev.write(STATUS, STATUS_ACK, 2, &mem); // Wrong size. + assert_eq!(dev.read(STATUS, 4), 0); // Should be unchanged. + } + + // --- Invalid queue selection --- + + #[test] + fn test_invalid_queue_sel() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem = MockMem::new(64); + dev.write(QUEUE_SEL, 99, 4, &mem); + assert_eq!(dev.read(QUEUE_NUM_MAX, 4), 0); // No such queue. + } + + // --- Poll --- + + #[test] + fn test_poll_default_returns_false() { + let mut dev = VirtioMmioDevice::new(NullBackend); + let mem = MockMem::new(64); + assert!(!dev.poll(&mem)); + assert_eq!(dev.interrupt_status(), 0); + } + + /// Backend that returns true from poll(). 
+ struct PollBackend; + + impl VirtioDeviceBackend for PollBackend { + fn device_id(&self) -> u32 { + 19 + } + fn device_features(&self, _page: u32) -> u32 { + 0 + } + fn read_config(&self, _offset: u64) -> u32 { + 0 + } + fn queue_notify( + &mut self, + _queue_idx: u32, + _queue: &mut Virtqueue, + _mem: &dyn GuestMemoryAccessor, + ) -> bool { + false + } + fn num_queues(&self) -> usize { + 1 + } + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + 128 + } + fn poll(&mut self, _queues: &mut [Virtqueue], _mem: &dyn GuestMemoryAccessor) -> bool { + true + } + } + + #[test] + fn test_poll_sets_interrupt_status() { + let mut dev = VirtioMmioDevice::new(PollBackend); + let mem = MockMem::new(64); + let raised = dev.poll(&mem); + assert!(raised); + assert_eq!(dev.interrupt_status(), INTERRUPT_USED_RING); + } + + #[test] + fn test_poll_interrupt_can_be_acked() { + let mut dev = VirtioMmioDevice::new(PollBackend); + let mem = MockMem::new(64); + dev.poll(&mem); + assert_eq!(dev.interrupt_status(), INTERRUPT_USED_RING); + dev.write(INTERRUPT_ACK, INTERRUPT_USED_RING, 4, &mem); + assert_eq!(dev.interrupt_status(), 0); + } +} diff --git a/src/vmm/src/windows/devices/virtio/mod.rs b/src/vmm/src/windows/devices/virtio/mod.rs new file mode 100644 index 000000000..1a2a8d6b6 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/mod.rs @@ -0,0 +1,21 @@ +//! Virtio device emulation. +//! +//! Implements the virtio specification (v1.2) over the MMIO transport +//! for paravirtualized device I/O. Currently supports: +//! - virtio-blk: block device (file-backed disk) +//! - virtio-vsock: socket transport (host TCP <-> guest AF_VSOCK) +//! - virtio-9p: filesystem sharing (host directory <-> guest 9P mount) +//! - virtio-net: network device (userspace proxy via passt/gvproxy) +//! - virtio-rng: entropy source (host OS random) +//! 
- virtio-balloon: dynamic memory management + +pub mod balloon; +pub mod block; +pub mod block_worker; +pub mod disk; +pub mod mmio; +pub mod net; +pub mod p9; +pub mod queue; +pub mod rng; +pub mod vsock; diff --git a/src/vmm/src/windows/devices/virtio/net.rs b/src/vmm/src/windows/devices/virtio/net.rs new file mode 100644 index 000000000..7c78e4f8c --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/net.rs @@ -0,0 +1,896 @@ +//! Virtio-net device backend (virtio spec v1.2 Section 5.1). +//! +//! Provides a network device backed by a userspace networking proxy +//! (passt/gvproxy) via a stream socket. The wire protocol uses +//! length-prefixed Ethernet frames: `[4-byte BE length][frame bytes]`. +//! +//! Queue layout: +//! Queue 0 (RX): host -> guest (device writes, guest reads) +//! Queue 1 (TX): guest -> host (guest writes, device reads) + +use std::collections::VecDeque; +use std::io::{self, Read, Write}; + +use super::mmio::VirtioDeviceBackend; +use super::queue::{GuestMemoryAccessor, Virtqueue}; + +/// Virtio device ID for network devices. +const VIRTIO_NET_ID: u32 = 1; + +/// VIRTIO_NET_F_MAC — device has given MAC address (bit 5). +const VIRTIO_NET_F_MAC: u32 = 5; + +/// VIRTIO_NET_F_STATUS — device provides link status (bit 16). +const VIRTIO_NET_F_STATUS: u32 = 16; + +/// VIRTIO_F_VERSION_1 — bit 32 (page 1, bit 0). +const VIRTIO_F_VERSION_1_BIT: u32 = 0; + +/// Number of queues: RX and TX (no control queue). +const NUM_QUEUES: usize = 2; + +/// Queue index constants. +const RX_QUEUE: usize = 0; +const TX_QUEUE: usize = 1; + +/// Maximum queue size. +const QUEUE_MAX_SIZE: u16 = 256; + +/// Size of struct virtio_net_hdr_v1 in bytes. +const VIRTIO_NET_HDR_SIZE: usize = 12; + +/// Network link status: up. +const VIRTIO_NET_S_LINK_UP: u16 = 1; + +/// Transport trait for pluggable networking backends. +/// +/// Unix socket transports use the passt/gvproxy wire +/// protocol: each frame is `[4-byte big-endian length][frame bytes]`. 
pub trait NetTransport: Send {
    /// Try to receive a complete Ethernet frame.
    ///
    /// Non-blocking: returns `None` when no complete frame is available yet.
    fn recv_frame(&mut self) -> Option<Vec<u8>>;

    /// Send one Ethernet frame to the peer.
    fn send_frame(&mut self, frame: &[u8]) -> io::Result<()>;
}

/// Largest frame body (in bytes) accepted from the peer.
const MAX_FRAME_LEN: usize = 65536;

/// Receive state machine for length-prefixed framing.
///
/// The state is kept between calls, so a frame that arrives in several
/// pieces is reassembled incrementally.
enum RecvState {
    /// Waiting for the 4-byte length header; `bytes_read` bytes read so far.
    LenPending { bytes_read: usize, buf: [u8; 4] },
    /// Length header complete, reading `frame_len` bytes of frame body.
    BodyPending {
        frame_len: usize,
        buf: Vec<u8>,
        bytes_read: usize,
    },
    /// The header announced more than `MAX_FRAME_LEN` bytes; `remaining`
    /// body bytes are still to be read and thrown away so that the next
    /// length header is parsed from the right position.
    Discarding { remaining: usize },
}

impl Default for RecvState {
    fn default() -> Self {
        RecvState::LenPending {
            bytes_read: 0,
            buf: [0u8; 4],
        }
    }
}

/// Unix stream socket transport (macOS/Linux).
#[cfg(unix)]
pub struct UnixStreamTransport {
    stream: std::os::unix::net::UnixStream,
    state: RecvState,
}

#[cfg(unix)]
impl UnixStreamTransport {
    /// Wrap a Unix stream socket, switching it to non-blocking mode.
    ///
    /// # Errors
    ///
    /// Returns the error from `set_nonblocking` if the mode change fails.
    pub fn new(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
        stream.set_nonblocking(true)?;
        Ok(UnixStreamTransport {
            stream,
            state: RecvState::default(),
        })
    }
}

#[cfg(unix)]
impl NetTransport for UnixStreamTransport {
    fn recv_frame(&mut self) -> Option<Vec<u8>> {
        recv_frame_from(&mut self.stream, &mut self.state)
    }

    fn send_frame(&mut self, frame: &[u8]) -> io::Result<()> {
        send_frame_to(&mut self.stream, frame)
    }
}

/// Unix domain socket transport (Windows, via uds_windows crate).
#[cfg(windows)]
pub struct UdsTransport {
    stream: uds_windows::UnixStream,
    state: RecvState,
}

#[cfg(windows)]
impl UdsTransport {
    /// Wrap a Unix domain socket stream, switching it to non-blocking mode.
    ///
    /// # Errors
    ///
    /// Returns the error from `set_nonblocking` if the mode change fails.
    pub fn new(stream: uds_windows::UnixStream) -> io::Result<Self> {
        stream.set_nonblocking(true)?;
        Ok(UdsTransport {
            stream,
            state: RecvState::default(),
        })
    }
}

#[cfg(windows)]
impl NetTransport for UdsTransport {
    fn recv_frame(&mut self) -> Option<Vec<u8>> {
        recv_frame_from(&mut self.stream, &mut self.state)
    }

    fn send_frame(&mut self, frame: &[u8]) -> io::Result<()> {
        send_frame_to(&mut self.stream, frame)
    }
}

/// Perform one read, retrying transparently when the call is interrupted.
///
/// Returns `None` on end-of-stream, `WouldBlock`, or any other error; the
/// caller's state is left untouched so a later call can resume.
fn read_some<R: Read>(reader: &mut R, buf: &mut [u8]) -> Option<usize> {
    loop {
        match reader.read(buf) {
            Ok(0) => return None, // EOF
            Ok(n) => return Some(n),
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(_) => return None,
        }
    }
}

/// Shared recv implementation using the state machine.
///
/// Reads as much as `reader` will give without blocking and returns the next
/// complete frame, or `None` if the frame is still incomplete. Progress is
/// stored in `state`, so calling again later continues where this call
/// stopped.
///
/// A zero-length header is dropped and `None` is returned. A header larger
/// than `MAX_FRAME_LEN` causes the announced body to be read and discarded
/// before the next header is parsed.
fn recv_frame_from<R: Read>(reader: &mut R, state: &mut RecvState) -> Option<Vec<u8>> {
    loop {
        match state {
            RecvState::LenPending { bytes_read, buf } => {
                let n = read_some(reader, &mut buf[*bytes_read..])?;
                *bytes_read += n;
                if *bytes_read < buf.len() {
                    continue;
                }

                let frame_len = u32::from_be_bytes(*buf) as usize;
                if frame_len == 0 {
                    // Nothing follows an empty frame; just wait for the next header.
                    *state = RecvState::default();
                    return None;
                }
                *state = if frame_len > MAX_FRAME_LEN {
                    // Skip the body instead of parsing it as headers.
                    RecvState::Discarding {
                        remaining: frame_len,
                    }
                } else {
                    RecvState::BodyPending {
                        frame_len,
                        buf: vec![0u8; frame_len],
                        bytes_read: 0,
                    }
                };
            }
            RecvState::BodyPending {
                frame_len,
                buf,
                bytes_read,
            } => {
                let n = read_some(reader, &mut buf[*bytes_read..])?;
                *bytes_read += n;
                if *bytes_read == *frame_len {
                    let frame = std::mem::take(buf);
                    *state = RecvState::default();
                    return Some(frame);
                }
            }
            RecvState::Discarding { remaining } => {
                let mut scratch = [0u8; 4096];
                let want = (*remaining).min(scratch.len());
                let n = read_some(reader, &mut scratch[..want])?;
                *remaining -= n;
                if *remaining == 0 {
                    *state = RecvState::default();
                }
            }
        }
    }
}

/// Shared send implementation: 4-byte BE length + frame bytes.
///
/// # Errors
///
/// Returns `InvalidInput` if the frame length does not fit the 32-bit
/// prefix, or the error from the underlying `write_all` calls.
///
/// NOTE(review): if `writer` is non-blocking, `write_all` can fail with
/// `WouldBlock` after part of the frame was written — confirm how callers
/// are expected to recover from that.
fn send_frame_to<W: Write>(writer: &mut W, frame: &[u8]) -> io::Result<()> {
    if frame.len() > u32::MAX as usize {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "frame too large for 32-bit length prefix",
        ));
    }
    // Lossless: bounded by the check above.
    let len_prefix = (frame.len() as u32).to_be_bytes();
    writer.write_all(&len_prefix)?;
    writer.write_all(frame)?;
    Ok(())
}
+/// +/// The first three bytes are `52:54:00` (QEMU/KVM OUI prefix). +/// The remaining bytes are derived from `seed`. +pub fn generate_mac(seed: u32) -> [u8; 6] { + let b = seed.to_le_bytes(); + [0x52, 0x54, 0x00, b[0], b[1], b[2]] +} + +/// Virtio-net device backed by a userspace networking proxy. +pub struct VirtioNet { + /// MAC address exposed to the guest. + mac: [u8; 6], + /// Network transport (socket to passt/gvproxy). + transport: Option>, + /// Frames waiting for RX queue space. + rx_pending: VecDeque>, +} + +impl VirtioNet { + /// Create a new virtio-net device with the given MAC and transport. + pub fn new(mac: [u8; 6], transport: Option>) -> Self { + VirtioNet { + mac, + transport, + rx_pending: VecDeque::new(), + } + } + + /// Get the MAC address. + pub fn mac(&self) -> &[u8; 6] { + &self.mac + } + + /// Process the TX queue: read frames from guest, send to transport. + fn process_tx(&mut self, queue: &mut Virtqueue, mem: &dyn GuestMemoryAccessor) -> bool { + let mut processed = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + }; + + if chain.is_empty() { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + + // Collect all data from device-readable descriptors. + let mut data = Vec::new(); + for desc in &chain { + if !desc.is_write() { + let mut buf = vec![0u8; desc.len as usize]; + if mem.read_at(desc.addr, &mut buf).is_ok() { + data.extend_from_slice(&buf); + } + } + } + + // First VIRTIO_NET_HDR_SIZE bytes are the virtio_net_hdr — strip it. + if data.len() > VIRTIO_NET_HDR_SIZE { + let frame = &data[VIRTIO_NET_HDR_SIZE..]; + if let Some(ref mut transport) = self.transport { + let _ = transport.send_frame(frame); + } + } + + let _ = queue.add_used(head, 0, mem); + processed = true; + } + + processed + } + + /// Inject pending frames into the RX queue. 
+ fn inject_rx(&mut self, rx_queue: &mut Virtqueue, mem: &dyn GuestMemoryAccessor) -> bool { + let mut injected = false; + + while !self.rx_pending.is_empty() { + let head = match rx_queue.pop_avail(mem) { + Ok(Some(h)) => h, + _ => break, // No available RX buffers. + }; + + let chain = match rx_queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = rx_queue.add_used(head, 0, mem); + injected = true; + continue; + } + }; + + let frame = self.rx_pending.pop_front().unwrap(); + + // Prepend a zero virtio_net_hdr. + let hdr = [0u8; VIRTIO_NET_HDR_SIZE]; + let total_data: Vec = hdr.iter().chain(frame.iter()).copied().collect(); + + let mut offset = 0; + let mut total_written = 0u32; + for desc in &chain { + if !desc.is_write() { + continue; + } + let remaining = total_data.len().saturating_sub(offset); + let to_write = remaining.min(desc.len as usize); + if to_write > 0 { + let _ = mem.write_at(desc.addr, &total_data[offset..offset + to_write]); + offset += to_write; + total_written += to_write as u32; + } + } + + let _ = rx_queue.add_used(head, total_written, mem); + injected = true; + } + + injected + } +} + +impl VirtioDeviceBackend for VirtioNet { + fn device_id(&self) -> u32 { + VIRTIO_NET_ID + } + + fn device_features(&self, page: u32) -> u32 { + match page { + 0 => (1 << VIRTIO_NET_F_MAC) | (1 << VIRTIO_NET_F_STATUS), + 1 => 1 << VIRTIO_F_VERSION_1_BIT, + _ => 0, + } + } + + fn read_config(&self, offset: u64) -> u32 { + // Config space layout (virtio spec 5.1.4): + // offset 0: mac[0..3] (4 bytes as u32 LE) + // offset 4: mac[4..5] + status (u16 each, packed as u32 LE) + // offset 6: status (u16) — but guest typically reads at offset 4 + match offset { + 0 => u32::from_le_bytes([self.mac[0], self.mac[1], self.mac[2], self.mac[3]]), + 4 => { + // mac[4], mac[5], status_lo, status_hi + let status = VIRTIO_NET_S_LINK_UP; + u32::from_le_bytes([ + self.mac[4], + self.mac[5], + (status & 0xFF) as u8, + ((status >> 8) & 0xFF) as u8, + ]) + } + _ => 
0, + } + } + + fn queue_notify( + &mut self, + queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + match queue_idx as usize { + TX_QUEUE => self.process_tx(queue, mem), + _ => false, + } + } + + fn num_queues(&self) -> usize { + NUM_QUEUES + } + + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + QUEUE_MAX_SIZE + } + + fn poll(&mut self, queues: &mut [Virtqueue], mem: &dyn GuestMemoryAccessor) -> bool { + // Drain available frames from the transport. + if let Some(ref mut transport) = self.transport { + while let Some(frame) = transport.recv_frame() { + self.rx_pending.push_back(frame); + } + } + + // Inject pending frames into the RX queue. + if queues.len() > RX_QUEUE { + self.inject_rx(&mut queues[RX_QUEUE], mem) + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::super::super::error::{Result, WkrunError}; + use super::queue::Virtqueue; + use super::*; + use std::cell::RefCell; + + struct MockMem { + data: RefCell>, + } + + impl MockMem { + fn new(size: usize) -> Self { + MockMem { + data: RefCell::new(vec![0u8; size]), + } + } + + fn write_bytes(&self, addr: u64, bytes: &[u8]) { + let a = addr as usize; + let mut data = self.data.borrow_mut(); + data[a..a + bytes.len()].copy_from_slice(bytes); + } + + fn read_bytes(&self, addr: u64, len: usize) -> Vec { + let a = addr as usize; + let data = self.data.borrow(); + data[a..a + len].to_vec() + } + + fn write_u16_at(&self, addr: u64, val: u16) { + self.write_bytes(addr, &val.to_le_bytes()); + } + + fn write_u32_at(&self, addr: u64, val: u32) { + self.write_bytes(addr, &val.to_le_bytes()); + } + + fn write_u64_at(&self, addr: u64, val: u64) { + self.write_bytes(addr, &val.to_le_bytes()); + } + } + + impl GuestMemoryAccessor for MockMem { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + let a = addr as usize; + let data = self.data.borrow(); + if a + buf.len() > data.len() { + return Err(WkrunError::Memory("out of bounds".into())); + } + 
buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + let a = addr as usize; + let mut mem = self.data.borrow_mut(); + if a + data.len() > mem.len() { + return Err(WkrunError::Memory("out of bounds".into())); + } + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + // Memory layout for tests. + const DESC_TABLE: u64 = 0x0000; + const DESC_SIZE: u64 = 16; + const AVAIL_RING: u64 = 0x0800; + const USED_RING: u64 = 0x1000; + const BUF_BASE: u64 = 0x2000; + + fn setup_queue(max_size: u16) -> Virtqueue { + let mut q = Virtqueue::new(max_size); + q.set_size(max_size); + q.set_desc_table(DESC_TABLE); + q.set_avail_ring(AVAIL_RING); + q.set_used_ring(USED_RING); + q.set_ready(true); + q + } + + fn write_descriptor(mem: &MockMem, index: u16, addr: u64, len: u32, flags: u16, next: u16) { + let base = DESC_TABLE + index as u64 * DESC_SIZE; + mem.write_u64_at(base, addr); + mem.write_u32_at(base + 8, len); + mem.write_u16_at(base + 12, flags); + mem.write_u16_at(base + 14, next); + } + + fn push_avail(mem: &MockMem, ring_idx: u16, desc_head: u16) { + let entry_off = AVAIL_RING + 4 + (ring_idx as u64) * 2; + mem.write_u16_at(entry_off, desc_head); + mem.write_u16_at(AVAIL_RING + 2, ring_idx + 1); + } + + /// Mock transport with shared state for inspecting sent frames + /// and injecting received frames after the transport is owned by VirtioNet. 
+ struct SharedMockTransport { + sent: std::sync::Arc>>>, + recv_queue: std::sync::Arc>>>, + } + + impl SharedMockTransport { + fn new() -> (Self, SharedMockHandle) { + let sent = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); + let recv_queue = std::sync::Arc::new(std::sync::Mutex::new(VecDeque::new())); + let handle = SharedMockHandle { + sent: sent.clone(), + recv_queue: recv_queue.clone(), + }; + (SharedMockTransport { sent, recv_queue }, handle) + } + } + + impl NetTransport for SharedMockTransport { + fn recv_frame(&mut self) -> Option> { + self.recv_queue.lock().unwrap().pop_front() + } + + fn send_frame(&mut self, frame: &[u8]) -> io::Result<()> { + self.sent.lock().unwrap().push(frame.to_vec()); + Ok(()) + } + } + + struct SharedMockHandle { + sent: std::sync::Arc>>>, + recv_queue: std::sync::Arc>>>, + } + + impl SharedMockHandle { + fn push_recv(&self, frame: Vec) { + self.recv_queue.lock().unwrap().push_back(frame); + } + + fn sent_frames(&self) -> Vec> { + self.sent.lock().unwrap().clone() + } + } + + fn test_mac() -> [u8; 6] { + [0x52, 0x54, 0x00, 0x12, 0x34, 0x56] + } + + // --- Device identity --- + + #[test] + fn test_device_id() { + let dev = VirtioNet::new(test_mac(), None); + assert_eq!(dev.device_id(), 1); + } + + #[test] + fn test_num_queues() { + let dev = VirtioNet::new(test_mac(), None); + assert_eq!(dev.num_queues(), 2); + } + + #[test] + fn test_queue_max_size() { + let dev = VirtioNet::new(test_mac(), None); + assert_eq!(dev.queue_max_size(0), 256); + assert_eq!(dev.queue_max_size(1), 256); + } + + #[test] + fn test_features_page0() { + let dev = VirtioNet::new(test_mac(), None); + let features = dev.device_features(0); + assert_ne!(features & (1 << VIRTIO_NET_F_MAC), 0); + assert_ne!(features & (1 << VIRTIO_NET_F_STATUS), 0); + } + + #[test] + fn test_features_page1() { + let dev = VirtioNet::new(test_mac(), None); + assert_eq!(dev.device_features(1), 1); // VIRTIO_F_VERSION_1 + } + + // --- Config space --- + + #[test] + fn 
test_config_mac_offset_0() { + let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]; + let dev = VirtioNet::new(mac, None); + let val = dev.read_config(0); + assert_eq!(val, u32::from_le_bytes([0xAA, 0xBB, 0xCC, 0xDD])); + } + + #[test] + fn test_config_mac_offset_4() { + let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]; + let dev = VirtioNet::new(mac, None); + let val = dev.read_config(4); + // mac[4]=0xEE, mac[5]=0xFF, status=0x0001 (LINK_UP) + assert_eq!(val, u32::from_le_bytes([0xEE, 0xFF, 0x01, 0x00])); + } + + #[test] + fn test_config_status_link_up() { + let dev = VirtioNet::new(test_mac(), None); + let val = dev.read_config(4); + // Status is in bytes 2-3 of the u32 at offset 4. + let status = (val >> 16) as u16; + assert_eq!(status, VIRTIO_NET_S_LINK_UP); + } + + // --- TX queue --- + + #[test] + fn test_tx_sends_frame() { + let (transport, handle) = SharedMockTransport::new(); + let mut dev = VirtioNet::new(test_mac(), Some(Box::new(transport))); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(256); + + // Write virtio_net_hdr (12 zero bytes) + Ethernet frame to guest memory. + let mut tx_data = vec![0u8; VIRTIO_NET_HDR_SIZE]; + let frame = b"\xff\xff\xff\xff\xff\xff\x52\x54\x00\x12\x34\x56\x08\x00hello"; + tx_data.extend_from_slice(frame); + mem.write_bytes(BUF_BASE, &tx_data); + + // Single descriptor: header + frame (device-readable). + write_descriptor(&mem, 0, BUF_BASE, tx_data.len() as u32, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + + let sent = handle.sent_frames(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0], frame); // virtio_net_hdr stripped. 
+ } + + #[test] + fn test_tx_empty_chain_skipped() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(256); + + write_descriptor(&mem, 0, BUF_BASE, 0, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + } + + #[test] + fn test_tx_short_header_skipped() { + let (transport, handle) = SharedMockTransport::new(); + let mut dev = VirtioNet::new(test_mac(), Some(Box::new(transport))); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(256); + + // Only 8 bytes — shorter than virtio_net_hdr. + mem.write_bytes(BUF_BASE, &[0u8; 8]); + write_descriptor(&mem, 0, BUF_BASE, 8, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + assert!(handle.sent_frames().is_empty()); // Nothing sent. + } + + // --- RX queue --- + + #[test] + fn test_rx_inject_frame() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(256); + + let frame = b"\xff\xff\xff\xff\xff\xff\x52\x54\x00\x12\x34\x56\x08\x00data".to_vec(); + dev.rx_pending.push_back(frame.clone()); + + // RX buffer (device-writable). + write_descriptor(&mem, 0, BUF_BASE, 1500, 2, 0); // WRITE flag = 2 + push_avail(&mem, 0, 0); + + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(injected); + + // Check: 12-byte zero header + frame. + let hdr = mem.read_bytes(BUF_BASE, VIRTIO_NET_HDR_SIZE); + assert_eq!(hdr, vec![0u8; VIRTIO_NET_HDR_SIZE]); + let written_frame = mem.read_bytes(BUF_BASE + VIRTIO_NET_HDR_SIZE as u64, frame.len()); + assert_eq!(written_frame, frame); + } + + #[test] + fn test_rx_no_buffers_stays_pending() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(256); + // Don't push any available buffers. 
+ + dev.rx_pending.push_back(b"frame1".to_vec()); + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(!injected); + assert_eq!(dev.rx_pending.len(), 1); + } + + #[test] + fn test_rx_multiple_frames() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(256); + + dev.rx_pending.push_back(b"frame1".to_vec()); + dev.rx_pending.push_back(b"frame2".to_vec()); + + // Two RX buffers. + write_descriptor(&mem, 0, BUF_BASE, 1500, 2, 0); + push_avail(&mem, 0, 0); + write_descriptor(&mem, 1, BUF_BASE + 0x1000, 1500, 2, 0); + push_avail(&mem, 1, 1); + + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(injected); + assert!(dev.rx_pending.is_empty()); + + // Check first frame. + let f1 = mem.read_bytes(BUF_BASE + VIRTIO_NET_HDR_SIZE as u64, 6); + assert_eq!(f1, b"frame1"); + + // Check second frame. + let f2 = mem.read_bytes(BUF_BASE + 0x1000 + VIRTIO_NET_HDR_SIZE as u64, 6); + assert_eq!(f2, b"frame2"); + } + + // --- Poll --- + + #[test] + fn test_poll_reads_transport() { + let (transport, handle) = SharedMockTransport::new(); + let mut dev = VirtioNet::new(test_mac(), Some(Box::new(transport))); + let mem = MockMem::new(0x10000); + + handle.push_recv(b"incoming_frame".to_vec()); + + // Set up RX buffer. + write_descriptor(&mem, 0, BUF_BASE, 1500, 2, 0); + push_avail(&mem, 0, 0); + + let mut queues = vec![setup_queue(256), setup_queue(256)]; + // Point RX queue to our descriptors. + queues[0].set_desc_table(DESC_TABLE); + queues[0].set_avail_ring(AVAIL_RING); + queues[0].set_used_ring(USED_RING); + + let raised = dev.poll(&mut queues, &mem); + assert!(raised); + + // Frame should be in RX queue: 12-byte hdr + "incoming_frame". 
+ let total_len = VIRTIO_NET_HDR_SIZE + 14; + let written = mem.read_bytes(BUF_BASE, total_len); + assert_eq!(&written[..VIRTIO_NET_HDR_SIZE], &[0u8; VIRTIO_NET_HDR_SIZE]); + assert_eq!(&written[VIRTIO_NET_HDR_SIZE..], b"incoming_frame"); + } + + #[test] + fn test_poll_no_data() { + let (transport, _handle) = SharedMockTransport::new(); + let mut dev = VirtioNet::new(test_mac(), Some(Box::new(transport))); + let mem = MockMem::new(0x10000); + + let mut queues = vec![setup_queue(256), setup_queue(256)]; + let raised = dev.poll(&mut queues, &mem); + assert!(!raised); + } + + // --- Frame length prefix encoding/decoding --- + + #[test] + fn test_frame_length_prefix_encode() { + let mut buf = Vec::new(); + let frame = b"test frame data"; + send_frame_to(&mut buf, frame).unwrap(); + + assert_eq!(buf.len(), 4 + frame.len()); + let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]); + assert_eq!(len, frame.len() as u32); + assert_eq!(&buf[4..], frame); + } + + #[test] + fn test_frame_length_prefix_decode() { + let frame = b"hello ethernet"; + let mut wire = Vec::new(); + wire.extend_from_slice(&(frame.len() as u32).to_be_bytes()); + wire.extend_from_slice(frame); + + let mut state = RecvState::default(); + let mut cursor = io::Cursor::new(wire); + let result = recv_frame_from(&mut cursor, &mut state); + assert_eq!(result, Some(frame.to_vec())); + } + + // --- No transport --- + + #[test] + fn test_new_without_transport() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + + // TX should silently drop. + let mut tx_data = vec![0u8; VIRTIO_NET_HDR_SIZE]; + tx_data.extend_from_slice(b"dropped"); + mem.write_bytes(BUF_BASE, &tx_data); + write_descriptor(&mem, 0, BUF_BASE, tx_data.len() as u32, 0, 0); + push_avail(&mem, 0, 0); + let mut tx_queue = setup_queue(256); + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + + // Poll with no transport = false. 
+ let mut queues = vec![setup_queue(256), setup_queue(256)]; + assert!(!dev.poll(&mut queues, &mem)); + } + + // --- MAC generation --- + + #[test] + fn test_mac_generation() { + let mac = generate_mac(42); + assert_eq!(mac[0], 0x52); + assert_eq!(mac[1], 0x54); + assert_eq!(mac[2], 0x00); + // Remaining bytes from seed. + let b = 42u32.to_le_bytes(); + assert_eq!(mac[3], b[0]); + assert_eq!(mac[4], b[1]); + assert_eq!(mac[5], b[2]); + } + + #[test] + fn test_mac_generation_different_seeds() { + let mac1 = generate_mac(1); + let mac2 = generate_mac(2); + // Same OUI prefix. + assert_eq!(&mac1[..3], &mac2[..3]); + // Different generated portion. + assert_ne!(&mac1[3..], &mac2[3..]); + } + + // --- TX with chained descriptors --- + + #[test] + fn test_tx_chained_descriptors() { + let (transport, handle) = SharedMockTransport::new(); + let mut dev = VirtioNet::new(test_mac(), Some(Box::new(transport))); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(256); + + // Descriptor 0: virtio_net_hdr (device-readable), chained to 1. + let hdr = [0u8; VIRTIO_NET_HDR_SIZE]; + mem.write_bytes(BUF_BASE, &hdr); + write_descriptor( + &mem, + 0, + BUF_BASE, + VIRTIO_NET_HDR_SIZE as u32, + 1, // NEXT flag + 1, + ); + + // Descriptor 1: Ethernet frame (device-readable). + let frame = b"ethernet_frame_data"; + mem.write_bytes(BUF_BASE + 0x1000, frame); + write_descriptor(&mem, 1, BUF_BASE + 0x1000, frame.len() as u32, 0, 0); + + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + + let sent = handle.sent_frames(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0], frame); + } + + // --- Queue notify dispatch --- + + #[test] + fn test_queue_notify_rx_returns_false() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(256); + // Notify on RX queue should do nothing. 
+ assert!(!dev.queue_notify(0, &mut rx_queue, &mem)); + } + + #[test] + fn test_queue_notify_invalid_queue() { + let mut dev = VirtioNet::new(test_mac(), None); + let mem = MockMem::new(0x10000); + let mut queue = setup_queue(256); + assert!(!dev.queue_notify(99, &mut queue, &mem)); + } +} diff --git a/src/vmm/src/windows/devices/virtio/p9/filesystem.rs b/src/vmm/src/windows/devices/virtio/p9/filesystem.rs new file mode 100644 index 000000000..a68149df4 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/p9/filesystem.rs @@ -0,0 +1,1089 @@ +//! Host filesystem backend for 9P2000.L. +//! +//! Maps 9P operations to `std::fs` operations on a shared host directory. +//! Each FID maps to an open file or directory path. Security: all paths +//! are resolved relative to the root directory; traversal outside is rejected. + +use std::collections::HashMap; +use std::fs::{self, File, OpenOptions}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; + +use super::protocol::{ByteWriter, P9Attr, Qid, QT_DIR, QT_FILE, QT_SYMLINK}; + +/// Linux errno constants used in Rlerror responses. +pub const ENOENT: u32 = 2; +pub const EIO: u32 = 5; +pub const EBADF: u32 = 9; +pub const EACCES: u32 = 13; +pub const EEXIST: u32 = 17; +pub const ENOTDIR: u32 = 20; +pub const EINVAL: u32 = 22; +pub const ENOSPC: u32 = 28; +pub const ENOTEMPTY: u32 = 39; + +/// Linux open flags. +const O_RDONLY: u32 = 0; +const O_WRONLY: u32 = 1; +const O_RDWR: u32 = 2; +const O_CREAT: u32 = 0o100; +const O_TRUNC: u32 = 0o1000; +const O_APPEND: u32 = 0o2000; + +/// Getattr request mask bits (P9_GETATTR_*). 
+const P9_GETATTR_MODE: u64 = 0x00000001; +const P9_GETATTR_NLINK: u64 = 0x00000002; +const P9_GETATTR_UID: u64 = 0x00000004; +const P9_GETATTR_GID: u64 = 0x00000008; +const P9_GETATTR_RDEV: u64 = 0x00000010; +const P9_GETATTR_ATIME: u64 = 0x00000020; +const P9_GETATTR_MTIME: u64 = 0x00000040; +const P9_GETATTR_CTIME: u64 = 0x00000080; +const P9_GETATTR_SIZE: u64 = 0x00000200; +const P9_GETATTR_BLOCKS: u64 = 0x00000400; +const P9_GETATTR_BTIME: u64 = 0x00000800; +const P9_GETATTR_GEN: u64 = 0x00001000; +const P9_GETATTR_DATA_VERSION: u64 = 0x00002000; +/// Convenience mask for "all basic fields". +const P9_GETATTR_BASIC: u64 = P9_GETATTR_MODE + | P9_GETATTR_NLINK + | P9_GETATTR_UID + | P9_GETATTR_GID + | P9_GETATTR_RDEV + | P9_GETATTR_ATIME + | P9_GETATTR_MTIME + | P9_GETATTR_CTIME + | P9_GETATTR_SIZE + | P9_GETATTR_BLOCKS + | P9_GETATTR_BTIME + | P9_GETATTR_GEN + | P9_GETATTR_DATA_VERSION; + +/// Setattr valid bits. +const P9_SETATTR_MODE: u32 = 0x00000001; +const P9_SETATTR_SIZE: u32 = 0x00000008; + +/// Unlinkat flags. +const AT_REMOVEDIR: u32 = 0x200; + +/// FID state: tracks an open file or directory path. +struct FidState { + path: PathBuf, + file: Option, +} + +/// Host filesystem backend for 9P. +pub struct P9Filesystem { + root: PathBuf, + fids: HashMap, + read_only: bool, + msize: u32, + /// Path-to-QID-path cache for consistent QID.path values (used on non-Unix). + #[cfg(not(unix))] + qid_cache: HashMap, + /// Next synthetic QID path ID (used on non-Unix when inode not available). + #[cfg(not(unix))] + next_qid_path: u64, +} + +impl P9Filesystem { + pub fn new(root: PathBuf, read_only: bool) -> Self { + P9Filesystem { + root, + fids: HashMap::new(), + read_only, + msize: 0, + #[cfg(not(unix))] + qid_cache: HashMap::new(), + #[cfg(not(unix))] + next_qid_path: 1, + } + } + + /// Get the current msize. + pub fn msize(&self) -> u32 { + self.msize + } + + /// Negotiate protocol version. Returns negotiated msize. 
+ pub fn version(&mut self, client_msize: u32) -> u32 { + self.msize = client_msize.min(65536); + // Release all fids on version (per spec). + self.fids.clear(); + self.msize + } + + /// Attach: bind `fid` to the root directory. + pub fn attach(&mut self, fid: u32) -> Result { + let meta = fs::metadata(&self.root).map_err(|_| ENOENT)?; + let qid = self.make_qid(&self.root.clone(), &meta); + self.fids.insert( + fid, + FidState { + path: self.root.clone(), + file: None, + }, + ); + Ok(qid) + } + + /// Walk: resolve path components from `fid` into `newfid`. + pub fn walk(&mut self, fid: u32, newfid: u32, names: &[String]) -> Result, u32> { + let base_path = self.fids.get(&fid).ok_or(EBADF)?.path.clone(); + + if names.is_empty() { + // Clone fid. + self.fids.insert( + newfid, + FidState { + path: base_path, + file: None, + }, + ); + return Ok(Vec::new()); + } + + let mut current = base_path; + let mut qids = Vec::with_capacity(names.len()); + + for name in names { + if name == ".." || name.contains('/') || name.contains('\\') { + return Err(ENOENT); + } + current = current.join(name); + + // Security: verify the resolved path is under root. + if !self.is_under_root(¤t) { + return Err(EACCES); + } + + let meta = fs::metadata(¤t).map_err(|_| ENOENT)?; + qids.push(self.make_qid(¤t, &meta)); + } + + self.fids.insert( + newfid, + FidState { + path: current, + file: None, + }, + ); + + Ok(qids) + } + + /// Open a file for I/O. + pub fn lopen(&mut self, fid: u32, flags: u32) -> Result<(Qid, u32), u32> { + // Clone path to release borrow on self.fids before calling other &mut self methods. 
+ let path = self.fids.get(&fid).ok_or(EBADF)?.path.clone(); + let meta = fs::metadata(&path).map_err(|_| ENOENT)?; + + if meta.is_dir() { + let qid = self.make_qid_from_parts(&path, &meta); + let iounit = self.iounit(); + return Ok((qid, iounit)); + } + + if self.read_only && (flags & 0x3) != O_RDONLY { + return Err(EACCES); + } + + let file = self.open_file(&path, flags)?; + let qid = self.make_qid_from_parts(&path, &meta); + let iounit = self.iounit(); + self.fids.get_mut(&fid).ok_or(EBADF)?.file = Some(file); + Ok((qid, iounit)) + } + + /// Create and open a new file. + pub fn lcreate( + &mut self, + fid: u32, + name: &str, + _flags: u32, + _mode: u32, + _gid: u32, + ) -> Result<(Qid, u32), u32> { + if self.read_only { + return Err(EACCES); + } + + let dir_path = self.fids.get(&fid).ok_or(EBADF)?.path.clone(); + let file_path = dir_path.join(name); + + if !self.is_under_root(&file_path) { + return Err(EACCES); + } + + let file = OpenOptions::new() + .read(true) + .write(true) + .create_new(true) + .open(&file_path) + .map_err(|e| match e.kind() { + std::io::ErrorKind::AlreadyExists => EEXIST, + std::io::ErrorKind::PermissionDenied => EACCES, + _ => EIO, + })?; + + // Set permissions on Unix. + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = fs::set_permissions(&file_path, fs::Permissions::from_mode(_mode)); + } + + let meta = file.metadata().map_err(|_| EIO)?; + let qid = self.make_qid(&file_path, &meta); + let iounit = self.iounit(); + + // Fid now points to the new file. + let state = self.fids.get_mut(&fid).ok_or(EBADF)?; + state.path = file_path; + state.file = Some(file); + + Ok((qid, iounit)) + } + + /// Read from an open file. 
+ pub fn read(&mut self, fid: u32, offset: u64, count: u32) -> Result, u32> { + let state = self.fids.get_mut(&fid).ok_or(EBADF)?; + let file = state.file.as_mut().ok_or(EBADF)?; + + file.seek(SeekFrom::Start(offset)).map_err(|_| EIO)?; + + let max_read = count.min(self.msize.saturating_sub(11)) as usize; // 11 = header(7) + count(4) + let mut buf = vec![0u8; max_read]; + let n = file.read(&mut buf).map_err(|_| EIO)?; + buf.truncate(n); + Ok(buf) + } + + /// Write to an open file. + pub fn write(&mut self, fid: u32, offset: u64, data: &[u8]) -> Result { + if self.read_only { + return Err(EACCES); + } + + let state = self.fids.get_mut(&fid).ok_or(EBADF)?; + let file = state.file.as_mut().ok_or(EBADF)?; + + file.seek(SeekFrom::Start(offset)).map_err(|_| EIO)?; + file.write_all(data).map_err(|_| ENOSPC)?; + Ok(data.len() as u32) + } + + /// Read directory entries. + pub fn readdir(&mut self, fid: u32, offset: u64, count: u32) -> Result, u32> { + let state = self.fids.get(&fid).ok_or(EBADF)?; + let entries: Vec<_> = fs::read_dir(&state.path) + .map_err(|_| ENOTDIR)? + .filter_map(|e| e.ok()) + .collect(); + + let max_size = count.min(self.msize.saturating_sub(11)) as usize; + let mut w = ByteWriter::with_capacity(max_size); + let mut entry_offset = offset; + + for entry in entries.iter().skip(offset as usize) { + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + let meta = match entry.metadata() { + Ok(m) => m, + Err(_) => continue, + }; + + let qid = self.make_qid(&entry.path(), &meta); + let dtype = if meta.is_dir() { 4u8 } else { 8u8 }; + + // Readdir entry: qid[13] + offset[8] + type[1] + name[s] + let entry_size = 13 + 8 + 1 + 2 + name_str.len(); + if w.len() + entry_size > max_size { + break; + } + + entry_offset += 1; + qid.write_to(&mut w); + w.put_u64(entry_offset); + w.put_u8(dtype); + w.put_string(&name_str); + } + + Ok(w.into_bytes()) + } + + /// Get file attributes. 
+ pub fn getattr(&mut self, fid: u32, request_mask: u64) -> Result { + let state = self.fids.get(&fid).ok_or(EBADF)?; + let meta = fs::metadata(&state.path).map_err(|_| ENOENT)?; + let qid = self.make_qid(&state.path.clone(), &meta); + + let valid = request_mask & P9_GETATTR_BASIC; + + let mode = self.metadata_mode(&meta); + let size = meta.len(); + let blksize = 4096u64; + let blocks = size.div_ceil(512); + + // Timestamps. + let (mtime_sec, mtime_nsec) = self.metadata_mtime(&meta); + let (atime_sec, atime_nsec) = self.metadata_atime(&meta); + let (ctime_sec, ctime_nsec) = (mtime_sec, mtime_nsec); // Approximate. + + let nlink = self.metadata_nlink(&meta); + + Ok(P9Attr { + valid, + qid, + mode, + uid: 0, + gid: 0, + nlink, + rdev: 0, + size, + blksize, + blocks, + atime_sec, + atime_nsec, + mtime_sec, + mtime_nsec, + ctime_sec, + ctime_nsec, + btime_sec: 0, + btime_nsec: 0, + gen: 0, + data_version: 0, + }) + } + + /// Set file attributes. + pub fn setattr( + &mut self, + fid: u32, + valid: u32, + mode: u32, + _uid: u32, + _gid: u32, + size: u64, + ) -> Result<(), u32> { + if self.read_only { + return Err(EACCES); + } + + let state = self.fids.get(&fid).ok_or(EBADF)?; + + if valid & P9_SETATTR_MODE != 0 { + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = fs::Permissions::from_mode(mode); + fs::set_permissions(&state.path, perms).map_err(|_| EIO)?; + } + #[cfg(not(unix))] + let _ = mode; + } + + if valid & P9_SETATTR_SIZE != 0 { + let file = OpenOptions::new() + .write(true) + .open(&state.path) + .map_err(|_| EIO)?; + file.set_len(size).map_err(|_| EIO)?; + } + + Ok(()) + } + + /// Release a fid. + pub fn clunk(&mut self, fid: u32) -> Result<(), u32> { + self.fids.remove(&fid).ok_or(EBADF)?; + Ok(()) + } + + /// Create a directory. 
+ pub fn mkdir(&mut self, dfid: u32, name: &str, _mode: u32, _gid: u32) -> Result { + if self.read_only { + return Err(EACCES); + } + + let dir_path = self.fids.get(&dfid).ok_or(EBADF)?.path.clone(); + let new_path = dir_path.join(name); + + if !self.is_under_root(&new_path) { + return Err(EACCES); + } + + fs::create_dir(&new_path).map_err(|e| match e.kind() { + std::io::ErrorKind::AlreadyExists => EEXIST, + std::io::ErrorKind::PermissionDenied => EACCES, + _ => EIO, + })?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = fs::set_permissions(&new_path, fs::Permissions::from_mode(_mode)); + } + + let meta = fs::metadata(&new_path).map_err(|_| EIO)?; + Ok(self.make_qid(&new_path, &meta)) + } + + /// Rename a file or directory. + pub fn renameat( + &mut self, + olddirfid: u32, + oldname: &str, + newdirfid: u32, + newname: &str, + ) -> Result<(), u32> { + if self.read_only { + return Err(EACCES); + } + + let old_dir = self.fids.get(&olddirfid).ok_or(EBADF)?.path.clone(); + let new_dir = self.fids.get(&newdirfid).ok_or(EBADF)?.path.clone(); + + let old_path = old_dir.join(oldname); + let new_path = new_dir.join(newname); + + if !self.is_under_root(&old_path) || !self.is_under_root(&new_path) { + return Err(EACCES); + } + + fs::rename(&old_path, &new_path).map_err(|e| match e.kind() { + std::io::ErrorKind::NotFound => ENOENT, + std::io::ErrorKind::PermissionDenied => EACCES, + _ => EIO, + })?; + + Ok(()) + } + + /// Delete a file or directory. + pub fn unlinkat(&mut self, dirfid: u32, name: &str, flags: u32) -> Result<(), u32> { + if self.read_only { + return Err(EACCES); + } + + let dir_path = self.fids.get(&dirfid).ok_or(EBADF)?.path.clone(); + let target = dir_path.join(name); + + if !self.is_under_root(&target) { + return Err(EACCES); + } + + if flags & AT_REMOVEDIR != 0 { + fs::remove_dir(&target).map_err(|e| match e.kind() { + std::io::ErrorKind::NotFound => ENOENT, + _ => { + // Check if directory is not empty. 
+ if let Ok(mut entries) = fs::read_dir(&target) { + if entries.next().is_some() { + return ENOTEMPTY; + } + } + EIO + } + })?; + } else { + fs::remove_file(&target).map_err(|e| match e.kind() { + std::io::ErrorKind::NotFound => ENOENT, + std::io::ErrorKind::PermissionDenied => EACCES, + _ => EIO, + })?; + } + + Ok(()) + } + + /// Flush cached data to disk. + pub fn fsync(&mut self, fid: u32) -> Result<(), u32> { + let state = self.fids.get_mut(&fid).ok_or(EBADF)?; + if let Some(ref file) = state.file { + // sync_all may fail on read-only files (especially on Windows). + // This is harmless — there's nothing to flush for read-only handles. + let _ = file.sync_all(); + } + Ok(()) + } + + // -- Internal helpers -- + + /// I/O unit size: max data per read/write. + fn iounit(&self) -> u32 { + self.msize.saturating_sub(24) // Conservative: header + read/write overhead. + } + + /// Verify that `path` resolves under the root directory. + fn is_under_root(&self, path: &Path) -> bool { + // Use canonicalize if the path exists; otherwise check components. + if let Ok(canonical) = fs::canonicalize(path) { + if let Ok(root_canonical) = fs::canonicalize(&self.root) { + return canonical.starts_with(&root_canonical); + } + } + // Path doesn't exist yet (e.g., for create). Check the parent. + if let Some(parent) = path.parent() { + if let Ok(canonical_parent) = fs::canonicalize(parent) { + if let Ok(root_canonical) = fs::canonicalize(&self.root) { + return canonical_parent.starts_with(&root_canonical); + } + } + } + false + } + + /// Generate a QID from file metadata. 
+ fn make_qid(&mut self, path: &Path, meta: &fs::Metadata) -> Qid { + self.make_qid_from_parts(path, meta) + } + + fn make_qid_from_parts(&mut self, path: &Path, meta: &fs::Metadata) -> Qid { + let qtype = if meta.is_dir() { + QT_DIR + } else if meta.file_type().is_symlink() { + QT_SYMLINK + } else { + QT_FILE + }; + + let qid_path = self.resolve_qid_path(path, meta); + + let (mtime_sec, _) = self.metadata_mtime(meta); + let version = mtime_sec as u32; + + Qid { + qtype, + version, + path: qid_path, + } + } + + /// Get a unique QID path value for a file. + fn resolve_qid_path(&mut self, path: &Path, meta: &fs::Metadata) -> u64 { + // On Unix: use inode number directly. + #[cfg(unix)] + { + use std::os::unix::fs::MetadataExt; + let _ = path; // suppress unused on non-unix + meta.ino() + } + + // On non-Unix: use a cache mapping canonical paths to synthetic IDs. + #[cfg(not(unix))] + { + let _ = meta; + let canonical = fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf()); + if let Some(&id) = self.qid_cache.get(&canonical) { + id + } else { + let id = self.next_qid_path; + self.next_qid_path += 1; + self.qid_cache.insert(canonical, id); + id + } + } + } + + /// Extract file mode from metadata. + fn metadata_mode(&self, meta: &fs::Metadata) -> u32 { + #[cfg(unix)] + { + use std::os::unix::fs::MetadataExt; + meta.mode() + } + #[cfg(not(unix))] + { + let mut mode = 0o644u32; + if meta.is_dir() { + mode = 0o755 | 0o040000; // S_IFDIR + } else { + mode |= 0o100000; // S_IFREG + } + if meta.permissions().readonly() { + mode &= !0o222; // Remove write bits. + } + mode + } + } + + /// Extract mtime from metadata as (seconds, nanoseconds). 
+ fn metadata_mtime(&self, meta: &fs::Metadata) -> (u64, u64) { + #[cfg(unix)] + { + use std::os::unix::fs::MetadataExt; + (meta.mtime() as u64, meta.mtime_nsec() as u64) + } + #[cfg(not(unix))] + { + meta.modified() + .ok() + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| (d.as_secs(), d.subsec_nanos() as u64)) + .unwrap_or((0, 0)) + } + } + + /// Extract atime from metadata as (seconds, nanoseconds). + fn metadata_atime(&self, meta: &fs::Metadata) -> (u64, u64) { + #[cfg(unix)] + { + use std::os::unix::fs::MetadataExt; + (meta.atime() as u64, meta.atime_nsec() as u64) + } + #[cfg(not(unix))] + { + meta.accessed() + .ok() + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| (d.as_secs(), d.subsec_nanos() as u64)) + .unwrap_or((0, 0)) + } + } + + /// Extract nlink from metadata. + fn metadata_nlink(&self, meta: &fs::Metadata) -> u64 { + #[cfg(unix)] + { + use std::os::unix::fs::MetadataExt; + meta.nlink() + } + #[cfg(not(unix))] + { + let _ = meta; + 1 + } + } + + /// Open a file with Linux open flags mapped to Rust OpenOptions. 
+ fn open_file(&self, path: &Path, flags: u32) -> Result { + let access = flags & 0x3; + let mut opts = OpenOptions::new(); + + match access { + O_RDONLY => { + opts.read(true); + } + O_WRONLY => { + opts.write(true); + } + O_RDWR => { + opts.read(true).write(true); + } + _ => { + opts.read(true); + } + } + + if flags & O_CREAT != 0 { + opts.create(true); + } + if flags & O_TRUNC != 0 { + opts.truncate(true); + } + if flags & O_APPEND != 0 { + opts.append(true); + } + + opts.open(path).map_err(|e| match e.kind() { + std::io::ErrorKind::NotFound => ENOENT, + std::io::ErrorKind::PermissionDenied => EACCES, + _ => EIO, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write as IoWrite; + use tempfile::TempDir; + + fn setup() -> (TempDir, P9Filesystem) { + let tmp = TempDir::new().unwrap(); + let mut fs = P9Filesystem::new(tmp.path().to_path_buf(), false); + fs.version(8192); + (tmp, fs) + } + + fn setup_readonly() -> (TempDir, P9Filesystem) { + let tmp = TempDir::new().unwrap(); + let mut fs = P9Filesystem::new(tmp.path().to_path_buf(), true); + fs.version(8192); + (tmp, fs) + } + + fn create_file(dir: &Path, name: &str, content: &[u8]) { + let path = dir.join(name); + let mut f = File::create(&path).unwrap(); + f.write_all(content).unwrap(); + } + + fn create_subdir(dir: &Path, name: &str) { + fs::create_dir(dir.join(name)).unwrap(); + } + + // -- version -- + + #[test] + fn test_version_negotiates_msize() { + let tmp = TempDir::new().unwrap(); + let mut fs = P9Filesystem::new(tmp.path().to_path_buf(), false); + let msize = fs.version(65536); + assert_eq!(msize, 65536); + assert_eq!(fs.msize(), 65536); + } + + #[test] + fn test_version_caps_msize() { + let tmp = TempDir::new().unwrap(); + let mut fs = P9Filesystem::new(tmp.path().to_path_buf(), false); + let msize = fs.version(1_000_000); + assert_eq!(msize, 65536); // Capped. 
+ } + + // -- attach -- + + #[test] + fn test_attach_returns_dir_qid() { + let (_tmp, mut fs) = setup(); + let qid = fs.attach(0).unwrap(); + assert_eq!(qid.qtype, QT_DIR); + assert_ne!(qid.path, 0); + } + + // -- walk -- + + #[test] + fn test_walk_empty_clones_fid() { + let (_tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + let qids = fs.walk(0, 1, &[]).unwrap(); + assert!(qids.is_empty()); + } + + #[test] + fn test_walk_single_file() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "hello.txt", b"hello"); + fs.attach(0).unwrap(); + let qids = fs.walk(0, 1, &["hello.txt".to_string()]).unwrap(); + assert_eq!(qids.len(), 1); + assert_eq!(qids[0].qtype, QT_FILE); + } + + #[test] + fn test_walk_multiple_components() { + let (tmp, mut fs) = setup(); + create_subdir(tmp.path(), "a"); + create_subdir(&tmp.path().join("a"), "b"); + create_file(&tmp.path().join("a").join("b"), "c.txt", b"content"); + fs.attach(0).unwrap(); + let qids = fs + .walk( + 0, + 1, + &["a".to_string(), "b".to_string(), "c.txt".to_string()], + ) + .unwrap(); + assert_eq!(qids.len(), 3); + assert_eq!(qids[0].qtype, QT_DIR); + assert_eq!(qids[1].qtype, QT_DIR); + assert_eq!(qids[2].qtype, QT_FILE); + } + + #[test] + fn test_walk_nonexistent_returns_error() { + let (_tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + let result = fs.walk(0, 1, &["nonexistent".to_string()]); + assert_eq!(result, Err(ENOENT)); + } + + #[test] + fn test_walk_dotdot_rejected() { + let (_tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + let result = fs.walk(0, 1, &["..".to_string()]); + assert_eq!(result, Err(ENOENT)); + } + + #[test] + fn test_walk_bad_fid() { + let (_tmp, mut fs) = setup(); + let result = fs.walk(99, 1, &["foo".to_string()]); + assert_eq!(result, Err(EBADF)); + } + + // -- lopen + read + write -- + + #[test] + fn test_lopen_and_read() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "data.txt", b"hello world"); + fs.attach(0).unwrap(); + fs.walk(0, 1, 
&["data.txt".to_string()]).unwrap(); + let (qid, iounit) = fs.lopen(1, O_RDONLY).unwrap(); + assert_eq!(qid.qtype, QT_FILE); + assert!(iounit > 0); + + let data = fs.read(1, 0, 4096).unwrap(); + assert_eq!(data, b"hello world"); + } + + #[test] + fn test_read_with_offset() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "data.txt", b"hello world"); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["data.txt".to_string()]).unwrap(); + fs.lopen(1, O_RDONLY).unwrap(); + + let data = fs.read(1, 6, 4096).unwrap(); + assert_eq!(data, b"world"); + } + + #[test] + fn test_write_and_readback() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "out.txt", b""); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["out.txt".to_string()]).unwrap(); + fs.lopen(1, O_RDWR).unwrap(); + + let written = fs.write(1, 0, b"test data").unwrap(); + assert_eq!(written, 9); + + let data = fs.read(1, 0, 4096).unwrap(); + assert_eq!(data, b"test data"); + } + + // -- readdir -- + + #[test] + fn test_readdir_lists_entries() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "a.txt", b""); + create_file(tmp.path(), "b.txt", b""); + create_subdir(tmp.path(), "subdir"); + fs.attach(0).unwrap(); + fs.lopen(0, O_RDONLY).unwrap(); + + let data = fs.readdir(0, 0, 8192).unwrap(); + // Should contain directory entries for a.txt, b.txt, subdir. + assert!(!data.is_empty()); + } + + #[test] + fn test_readdir_offset_skips() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "a.txt", b""); + create_file(tmp.path(), "b.txt", b""); + create_file(tmp.path(), "c.txt", b""); + fs.attach(0).unwrap(); + + let full = fs.readdir(0, 0, 8192).unwrap(); + let partial = fs.readdir(0, 1, 8192).unwrap(); + // Partial should be smaller (skipped first entry). 
+ assert!(partial.len() < full.len()); + } + + // -- getattr -- + + #[test] + fn test_getattr_file() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "test.txt", b"12345"); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["test.txt".to_string()]).unwrap(); + + let attr = fs.getattr(1, 0x3FFF).unwrap(); + assert_eq!(attr.qid.qtype, QT_FILE); + assert_eq!(attr.size, 5); + assert!(attr.valid != 0); + } + + #[test] + fn test_getattr_dir() { + let (_tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + + let attr = fs.getattr(0, 0x3FFF).unwrap(); + assert_eq!(attr.qid.qtype, QT_DIR); + } + + // -- clunk -- + + #[test] + fn test_clunk_releases_fid() { + let (_tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + fs.clunk(0).unwrap(); + // Fid 0 no longer valid. + assert_eq!(fs.walk(0, 1, &[]), Err(EBADF)); + } + + #[test] + fn test_clunk_bad_fid() { + let (_tmp, mut fs) = setup(); + assert_eq!(fs.clunk(99), Err(EBADF)); + } + + // -- mkdir -- + + #[test] + fn test_mkdir_creates_directory() { + let (tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + let qid = fs.mkdir(0, "newdir", 0o755, 0).unwrap(); + assert_eq!(qid.qtype, QT_DIR); + assert!(tmp.path().join("newdir").is_dir()); + } + + #[test] + fn test_mkdir_already_exists() { + let (tmp, mut fs) = setup(); + create_subdir(tmp.path(), "existing"); + fs.attach(0).unwrap(); + assert_eq!(fs.mkdir(0, "existing", 0o755, 0), Err(EEXIST)); + } + + // -- lcreate -- + + #[test] + fn test_lcreate_creates_file() { + let (tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + let (qid, iounit) = fs.lcreate(0, "new.txt", O_RDWR, 0o644, 0).unwrap(); + assert_eq!(qid.qtype, QT_FILE); + assert!(iounit > 0); + assert!(tmp.path().join("new.txt").exists()); + } + + #[test] + fn test_lcreate_already_exists() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "exists.txt", b""); + fs.attach(0).unwrap(); + assert_eq!(fs.lcreate(0, "exists.txt", O_RDWR, 0o644, 0), Err(EEXIST)); + } + + // -- renameat -- + + #[test] + fn test_renameat() { + 
let (tmp, mut fs) = setup(); + create_file(tmp.path(), "old.txt", b"data"); + fs.attach(0).unwrap(); + // Clone fid for newfid. + fs.walk(0, 1, &[]).unwrap(); + fs.renameat(0, "old.txt", 1, "new.txt").unwrap(); + assert!(!tmp.path().join("old.txt").exists()); + assert!(tmp.path().join("new.txt").exists()); + } + + // -- unlinkat -- + + #[test] + fn test_unlinkat_file() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "del.txt", b""); + fs.attach(0).unwrap(); + fs.unlinkat(0, "del.txt", 0).unwrap(); + assert!(!tmp.path().join("del.txt").exists()); + } + + #[test] + fn test_unlinkat_dir() { + let (tmp, mut fs) = setup(); + create_subdir(tmp.path(), "rmdir"); + fs.attach(0).unwrap(); + fs.unlinkat(0, "rmdir", AT_REMOVEDIR).unwrap(); + assert!(!tmp.path().join("rmdir").exists()); + } + + #[test] + fn test_unlinkat_nonempty_dir() { + let (tmp, mut fs) = setup(); + create_subdir(tmp.path(), "notempty"); + create_file(&tmp.path().join("notempty"), "file.txt", b""); + fs.attach(0).unwrap(); + assert_eq!(fs.unlinkat(0, "notempty", AT_REMOVEDIR), Err(ENOTEMPTY)); + } + + // -- fsync -- + + #[test] + fn test_fsync_open_file() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "sync.txt", b"data"); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["sync.txt".to_string()]).unwrap(); + fs.lopen(1, O_RDONLY).unwrap(); + fs.fsync(1).unwrap(); + } + + // -- read-only mode -- + + #[test] + fn test_readonly_blocks_write() { + let (tmp, mut fs) = setup_readonly(); + create_file(tmp.path(), "file.txt", b"data"); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["file.txt".to_string()]).unwrap(); + // Open for write should fail. 
+ assert_eq!(fs.lopen(1, O_WRONLY), Err(EACCES)); + } + + #[test] + fn test_readonly_blocks_mkdir() { + let (_tmp, mut fs) = setup_readonly(); + fs.attach(0).unwrap(); + assert_eq!(fs.mkdir(0, "new", 0o755, 0), Err(EACCES)); + } + + #[test] + fn test_readonly_blocks_unlink() { + let (tmp, mut fs) = setup_readonly(); + create_file(tmp.path(), "nodel.txt", b""); + fs.attach(0).unwrap(); + assert_eq!(fs.unlinkat(0, "nodel.txt", 0), Err(EACCES)); + } + + #[test] + fn test_readonly_allows_read() { + let (tmp, mut fs) = setup_readonly(); + create_file(tmp.path(), "readable.txt", b"hello"); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["readable.txt".to_string()]).unwrap(); + fs.lopen(1, O_RDONLY).unwrap(); + let data = fs.read(1, 0, 4096).unwrap(); + assert_eq!(data, b"hello"); + } + + // -- path traversal security -- + + #[test] + fn test_walk_slash_rejected() { + let (_tmp, mut fs) = setup(); + fs.attach(0).unwrap(); + let result = fs.walk(0, 1, &["a/b".to_string()]); + assert_eq!(result, Err(ENOENT)); + } + + // -- setattr -- + + #[test] + fn test_setattr_truncate() { + let (tmp, mut fs) = setup(); + create_file(tmp.path(), "trunc.txt", b"hello world"); + fs.attach(0).unwrap(); + fs.walk(0, 1, &["trunc.txt".to_string()]).unwrap(); + fs.setattr(1, P9_SETATTR_SIZE, 0, 0, 0, 5).unwrap(); + + // Verify truncation. + let content = std::fs::read(tmp.path().join("trunc.txt")).unwrap(); + assert_eq!(content, b"hello"); + } +} diff --git a/src/vmm/src/windows/devices/virtio/p9/mod.rs b/src/vmm/src/windows/devices/virtio/p9/mod.rs new file mode 100644 index 000000000..a4816ad8e --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/p9/mod.rs @@ -0,0 +1,824 @@ +//! Virtio-9p device backend (virtio spec v1.2 Section 5.11). +//! +//! Provides a 9P2000.L filesystem share between guest and host. +//! The guest mounts the share via `mount -t 9p -o trans=virtio,version=9p2000.L `. +//! +//! Queue layout: +//! 
Queue 0 (request): bidirectional 9P messages + +pub mod filesystem; +pub mod protocol; + +use std::path::PathBuf; + +use super::mmio::VirtioDeviceBackend; +use super::queue::{GuestMemoryAccessor, Virtqueue}; + +use self::filesystem::P9Filesystem; +use self::protocol::*; + +/// Virtio device ID for 9P transport (spec Section 5.11). +const VIRTIO_9P_ID: u32 = 9; + +/// VIRTIO_F_VERSION_1 — bit 32 (page 1, bit 0). +const VIRTIO_F_VERSION_1_BIT: u32 = 0; + +/// VIRTIO_9P_MOUNT_TAG feature bit (page 0, bit 0). +const VIRTIO_9P_MOUNT_TAG_BIT: u32 = 0; + +/// Maximum queue size. +const QUEUE_MAX_SIZE: u16 = 128; + +/// Virtio-9p device with host filesystem backend. +pub struct Virtio9p { + /// Mount tag visible to the guest (max 255 bytes). + tag: String, + /// Filesystem backend. + fs: P9Filesystem, +} + +impl Virtio9p { + /// Create a new 9p device sharing `root_path` on the host. + /// + /// `tag` is the mount tag the guest uses to identify this share. + /// `root_path` is the host directory to expose. + /// `read_only` controls whether writes are permitted. + pub fn new(tag: &str, root_path: PathBuf, read_only: bool) -> Self { + Virtio9p { + tag: tag.to_string(), + fs: P9Filesystem::new(root_path, read_only), + } + } + + /// Get the mount tag. + pub fn tag(&self) -> &str { + &self.tag + } + + /// Process a single 9P request from a descriptor chain. + /// + /// Returns the response bytes to write back, and the total bytes + /// consumed from readable descriptors. 
+ fn process_request(&mut self, request: &[u8]) -> Vec { + let mut r = ByteReader::new(request); + + let hdr = match P9Header::read_from(&mut r) { + Some(h) => h, + None => return build_response(P9_RLERROR, 0, |w| write_rlerror(w, filesystem::EIO)), + }; + + let body = &request[P9_HEADER_SIZE..]; + let req = match parse_request(hdr.msg_type, body) { + Some(r) => r, + None => { + return build_response(P9_RLERROR, hdr.tag, |w| { + write_rlerror(w, filesystem::EINVAL) + }) + } + }; + + self.dispatch(hdr.tag, req) + } + + /// Dispatch a parsed request to the filesystem backend. + fn dispatch(&mut self, tag: u16, req: P9Request) -> Vec { + match req { + P9Request::Tversion { msize, version } => { + if version != "9P2000.L" { + return build_response(P9_RVERSION, tag, |w| { + write_rversion(w, msize, "unknown"); + }); + } + let negotiated = self.fs.version(msize); + build_response(P9_RVERSION, tag, |w| { + write_rversion(w, negotiated, "9P2000.L"); + }) + } + + P9Request::Tattach { fid, .. } => match self.fs.attach(fid) { + Ok(qid) => build_response(P9_RATTACH, tag, |w| write_rattach(w, &qid)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Twalk { fid, newfid, names } => match self.fs.walk(fid, newfid, &names) { + Ok(qids) => build_response(P9_RWALK, tag, |w| write_rwalk(w, &qids)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tlopen { fid, flags } => match self.fs.lopen(fid, flags) { + Ok((qid, iounit)) => { + build_response(P9_RLOPEN, tag, |w| write_rlopen(w, &qid, iounit)) + } + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tlcreate { + fid, + name, + flags, + mode, + gid, + } => match self.fs.lcreate(fid, &name, flags, mode, gid) { + Ok((qid, iounit)) => { + build_response(P9_RLCREATE, tag, |w| write_rlcreate(w, &qid, iounit)) + } + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tread { fid, offset, 
count } => match self.fs.read(fid, offset, count) { + Ok(data) => build_response(P9_RREAD, tag, |w| write_rread(w, &data)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Twrite { + fid, offset, data, .. + } => match self.fs.write(fid, offset, &data) { + Ok(count) => build_response(P9_RWRITE, tag, |w| write_rwrite(w, count)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Treaddir { fid, offset, count } => { + match self.fs.readdir(fid, offset, count) { + Ok(data) => build_response(P9_RREADDIR, tag, |w| write_rreaddir(w, &data)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + } + } + + P9Request::Tgetattr { fid, request_mask } => match self.fs.getattr(fid, request_mask) { + Ok(attr) => build_response(P9_RGETATTR, tag, |w| write_rgetattr(w, &attr)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tsetattr { + fid, + valid, + mode, + uid, + gid, + size, + .. + } => match self.fs.setattr(fid, valid, mode, uid, gid, size) { + Ok(()) => build_response(P9_RSETATTR, tag, write_rsetattr), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tclunk { fid } => match self.fs.clunk(fid) { + Ok(()) => build_response(P9_RCLUNK, tag, write_rclunk), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tflush { .. 
} => build_response(P9_RFLUSH, tag, write_rflush), + + P9Request::Tmkdir { + dfid, + name, + mode, + gid, + } => match self.fs.mkdir(dfid, &name, mode, gid) { + Ok(qid) => build_response(P9_RMKDIR, tag, |w| write_rmkdir(w, &qid)), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Trenameat { + olddirfid, + oldname, + newdirfid, + newname, + } => match self.fs.renameat(olddirfid, &oldname, newdirfid, &newname) { + Ok(()) => build_response(P9_RRENAMEAT, tag, write_rrenameat), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tunlinkat { + dirfid, + name, + flags, + } => match self.fs.unlinkat(dirfid, &name, flags) { + Ok(()) => build_response(P9_RUNLINKAT, tag, write_runlinkat), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + + P9Request::Tfsync { fid } => match self.fs.fsync(fid) { + Ok(()) => build_response(P9_RFSYNC, tag, write_rfsync), + Err(e) => build_response(P9_RLERROR, tag, |w| write_rlerror(w, e)), + }, + } + } +} + +impl VirtioDeviceBackend for Virtio9p { + fn device_id(&self) -> u32 { + VIRTIO_9P_ID + } + + fn device_features(&self, page: u32) -> u32 { + match page { + 0 => 1 << VIRTIO_9P_MOUNT_TAG_BIT, + 1 => 1 << VIRTIO_F_VERSION_1_BIT, + _ => 0, + } + } + + fn read_config(&self, offset: u64) -> u32 { + // Config space layout: + // offset 0: tag_len (u16) — only low 16 bits of the u32 read + // offset 2+: tag bytes (padded to u32 alignment) + let tag_bytes = self.tag.as_bytes(); + let tag_len = tag_bytes.len() as u16; + + if offset == 0 { + // tag_len at offset 0 (u16) + first 2 bytes of tag at offset 2. + let mut val = tag_len as u32; + if !tag_bytes.is_empty() { + val |= (tag_bytes[0] as u32) << 16; + } + if tag_bytes.len() > 1 { + val |= (tag_bytes[1] as u32) << 24; + } + val + } else { + // Subsequent 4-byte reads into the tag string. + // offset is relative to config space start. + // tag starts at byte 2 within config space. 
+ let tag_start = offset as usize - 2; + let mut bytes = [0u8; 4]; + for (i, byte) in bytes.iter_mut().enumerate() { + let tidx = tag_start + i; + if tidx < tag_bytes.len() { + *byte = tag_bytes[tidx]; + } + } + u32::from_le_bytes(bytes) + } + } + + fn queue_notify( + &mut self, + _queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + let mut processed = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + }; + + if chain.is_empty() { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + + // Collect request from device-readable descriptors. + let mut request = Vec::new(); + for desc in &chain { + if !desc.is_write() { + let mut buf = vec![0u8; desc.len as usize]; + if mem.read_at(desc.addr, &mut buf).is_ok() { + request.extend_from_slice(&buf); + } + } + } + + // Process the 9P request. + let response = self.process_request(&request); + + // Write response to device-writable descriptors. + let mut offset = 0; + let mut total_written = 0u32; + for desc in &chain { + if !desc.is_write() { + continue; + } + let remaining = response.len().saturating_sub(offset); + let to_write = remaining.min(desc.len as usize); + if to_write > 0 { + let _ = mem.write_at(desc.addr, &response[offset..offset + to_write]); + offset += to_write; + total_written += to_write as u32; + } + } + + let _ = queue.add_used(head, total_written, mem); + processed = true; + } + + processed + } + + fn num_queues(&self) -> usize { + 1 // Single request queue. 
+ } + + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + QUEUE_MAX_SIZE + } +} + +#[cfg(test)] +mod tests { + use super::super::super::super::error::Result; + use super::super::queue::Virtqueue; + use super::*; + use std::cell::RefCell; + use std::io::Write as IoWrite; + use tempfile::TempDir; + + struct MockMem { + data: RefCell>, + } + + impl MockMem { + fn new(size: usize) -> Self { + MockMem { + data: RefCell::new(vec![0u8; size]), + } + } + + fn write_bytes(&self, addr: u64, bytes: &[u8]) { + let a = addr as usize; + let mut data = self.data.borrow_mut(); + data[a..a + bytes.len()].copy_from_slice(bytes); + } + + fn read_bytes(&self, addr: u64, len: usize) -> Vec { + let a = addr as usize; + let data = self.data.borrow(); + data[a..a + len].to_vec() + } + + fn write_u16_at(&self, addr: u64, val: u16) { + self.write_bytes(addr, &val.to_le_bytes()); + } + + fn write_u32_at(&self, addr: u64, val: u32) { + self.write_bytes(addr, &val.to_le_bytes()); + } + + fn write_u64_at(&self, addr: u64, val: u64) { + self.write_bytes(addr, &val.to_le_bytes()); + } + } + + impl GuestMemoryAccessor for MockMem { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + let a = addr as usize; + let data = self.data.borrow(); + if a + buf.len() > data.len() { + return Err(super::super::super::super::error::WkrunError::Memory( + "out of bounds".into(), + )); + } + buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + let a = addr as usize; + let mut mem = self.data.borrow_mut(); + if a + data.len() > mem.len() { + return Err(super::super::super::super::error::WkrunError::Memory( + "out of bounds".into(), + )); + } + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + // Memory layout for tests. 
+ const DESC_TABLE: u64 = 0x0000; + const DESC_SIZE: u64 = 16; + const AVAIL_RING: u64 = 0x0800; + const USED_RING: u64 = 0x1000; + const BUF_BASE: u64 = 0x2000; + const RESP_BASE: u64 = 0x4000; + + fn setup_queue(max_size: u16) -> Virtqueue { + let mut q = Virtqueue::new(max_size); + q.set_size(max_size); + q.set_desc_table(DESC_TABLE); + q.set_avail_ring(AVAIL_RING); + q.set_used_ring(USED_RING); + q.set_ready(true); + q + } + + fn write_descriptor(mem: &MockMem, index: u16, addr: u64, len: u32, flags: u16, next: u16) { + let base = DESC_TABLE + index as u64 * DESC_SIZE; + mem.write_u64_at(base, addr); + mem.write_u32_at(base + 8, len); + mem.write_u16_at(base + 12, flags); + mem.write_u16_at(base + 14, next); + } + + fn push_avail(mem: &MockMem, ring_idx: u16, desc_head: u16) { + let entry_off = AVAIL_RING + 4 + (ring_idx as u64) * 2; + mem.write_u16_at(entry_off, desc_head); + mem.write_u16_at(AVAIL_RING + 2, ring_idx + 1); + } + + fn create_test_device(tmp: &TempDir) -> Virtio9p { + Virtio9p::new("hostshare", tmp.path().to_path_buf(), false) + } + + /// Submit a request through the virtqueue and return the response bytes. + fn submit_request( + dev: &mut Virtio9p, + mem: &MockMem, + queue: &mut Virtqueue, + request: &[u8], + avail_idx: u16, + ) -> Vec { + let desc_base = avail_idx * 2; + mem.write_bytes(BUF_BASE, request); + + // Descriptor 0: request (device-readable), chained to 1. + write_descriptor( + mem, + desc_base, + BUF_BASE, + request.len() as u32, + 1, // NEXT flag + desc_base + 1, + ); + // Descriptor 1: response buffer (device-writable). + write_descriptor(mem, desc_base + 1, RESP_BASE, 8192, 2, 0); // WRITE flag + + push_avail(mem, avail_idx, desc_base); + + let raised = dev.queue_notify(0, queue, mem); + assert!(raised); + + // Read the response from RESP_BASE. + let resp_data = mem.read_bytes(RESP_BASE, 8192); + // Parse size from response. 
+ let size = u32::from_le_bytes([resp_data[0], resp_data[1], resp_data[2], resp_data[3]]); + resp_data[..size as usize].to_vec() + } + + fn build_tversion() -> Vec { + build_response(P9_TVERSION, P9_NOTAG, |w| { + w.put_u32(8192); + w.put_string("9P2000.L"); + }) + } + + fn build_tattach(fid: u32) -> Vec { + build_response(P9_TATTACH, 1, |w| { + w.put_u32(fid); + w.put_u32(P9_NOFID); + w.put_string(""); + w.put_string(""); + }) + } + + fn build_twalk(fid: u32, newfid: u32, names: &[&str]) -> Vec { + build_response(P9_TWALK, 2, |w| { + w.put_u32(fid); + w.put_u32(newfid); + w.put_u16(names.len() as u16); + for name in names { + w.put_string(name); + } + }) + } + + fn build_tlopen(fid: u32, flags: u32) -> Vec { + build_response(P9_TLOPEN, 3, |w| { + w.put_u32(fid); + w.put_u32(flags); + }) + } + + fn build_tread(fid: u32, offset: u64, count: u32) -> Vec { + build_response(P9_TREAD, 4, |w| { + w.put_u32(fid); + w.put_u64(offset); + w.put_u32(count); + }) + } + + fn build_tclunk(fid: u32) -> Vec { + build_response(P9_TCLUNK, 5, |w| { + w.put_u32(fid); + }) + } + + fn build_treaddir(fid: u32, offset: u64, count: u32) -> Vec { + build_response(P9_TREADDIR, 6, |w| { + w.put_u32(fid); + w.put_u64(offset); + w.put_u32(count); + }) + } + + fn build_tgetattr(fid: u32) -> Vec { + build_response(P9_TGETATTR, 7, |w| { + w.put_u32(fid); + w.put_u64(0x3FFF); // All attributes. 
+ }) + } + + // -- Device identity -- + + #[test] + fn test_device_id() { + let tmp = TempDir::new().unwrap(); + let dev = create_test_device(&tmp); + assert_eq!(dev.device_id(), 9); + } + + #[test] + fn test_num_queues() { + let tmp = TempDir::new().unwrap(); + let dev = create_test_device(&tmp); + assert_eq!(dev.num_queues(), 1); + } + + #[test] + fn test_queue_max_size() { + let tmp = TempDir::new().unwrap(); + let dev = create_test_device(&tmp); + assert_eq!(dev.queue_max_size(0), 128); + } + + #[test] + fn test_features() { + let tmp = TempDir::new().unwrap(); + let dev = create_test_device(&tmp); + assert_eq!(dev.device_features(0), 1); // VIRTIO_9P_MOUNT_TAG. + assert_eq!(dev.device_features(1), 1); // VIRTIO_F_VERSION_1. + assert_eq!(dev.device_features(2), 0); + } + + // -- Config space -- + + #[test] + fn test_config_tag_len() { + let tmp = TempDir::new().unwrap(); + let dev = Virtio9p::new("hostshare", tmp.path().to_path_buf(), false); + let val = dev.read_config(0); + // Low 16 bits = tag_len = 9 ("hostshare") + assert_eq!(val & 0xFFFF, 9); + } + + #[test] + fn test_tag() { + let tmp = TempDir::new().unwrap(); + let dev = create_test_device(&tmp); + assert_eq!(dev.tag(), "hostshare"); + } + + // -- Version negotiation -- + + #[test] + fn test_version_negotiation() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + + let resp = dev.process_request(&build_tversion()); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RVERSION); + let msize = r.get_u32().unwrap(); + assert_eq!(msize, 8192); + let version = r.get_string().unwrap(); + assert_eq!(version, "9P2000.L"); + } + + #[test] + fn test_version_unknown_protocol() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + + let msg = build_response(P9_TVERSION, P9_NOTAG, |w| { + w.put_u32(8192); + w.put_string("9P2000.u"); // Not supported. 
+ }); + let resp = dev.process_request(&msg); + let mut r = ByteReader::new(&resp); + let _hdr = P9Header::read_from(&mut r).unwrap(); + let _msize = r.get_u32().unwrap(); + let version = r.get_string().unwrap(); + assert_eq!(version, "unknown"); + } + + // -- Attach -- + + #[test] + fn test_attach() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + dev.process_request(&build_tversion()); + + let resp = dev.process_request(&build_tattach(0)); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RATTACH); + let qid = Qid::read_from(&mut r).unwrap(); + assert_eq!(qid.qtype, QT_DIR); + } + + // -- Walk + Read file through queue -- + + #[test] + fn test_walk_and_read_via_queue() { + let tmp = TempDir::new().unwrap(); + let mut file = std::fs::File::create(tmp.path().join("hello.txt")).unwrap(); + file.write_all(b"hello world").unwrap(); + drop(file); + + let mut dev = create_test_device(&tmp); + let mem = MockMem::new(0x10000); + let mut queue = setup_queue(128); + + // Version. + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tversion(), 0); + + // Attach. + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tattach(0), 1); + + // Walk to hello.txt. + let _resp = submit_request( + &mut dev, + &mem, + &mut queue, + &build_twalk(0, 1, &["hello.txt"]), + 2, + ); + + // Open. + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tlopen(1, 0), 3); + + // Read. 
+ let resp = submit_request(&mut dev, &mem, &mut queue, &build_tread(1, 0, 4096), 4); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RREAD); + let count = r.get_u32().unwrap(); + assert_eq!(count, 11); + let data = r.get_bytes(count as usize).unwrap(); + assert_eq!(data, b"hello world"); + } + + // -- Readdir via queue -- + + #[test] + fn test_readdir_via_queue() { + let tmp = TempDir::new().unwrap(); + std::fs::File::create(tmp.path().join("a.txt")).unwrap(); + std::fs::File::create(tmp.path().join("b.txt")).unwrap(); + + let mut dev = create_test_device(&tmp); + let mem = MockMem::new(0x10000); + let mut queue = setup_queue(128); + + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tversion(), 0); + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tattach(0), 1); + + // Open root dir. + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tlopen(0, 0), 2); + + // Readdir. + let resp = submit_request(&mut dev, &mem, &mut queue, &build_treaddir(0, 0, 8192), 3); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RREADDIR); + let count = r.get_u32().unwrap(); + assert!(count > 0); // Should contain entries. 
+ } + + // -- Getattr -- + + #[test] + fn test_getattr_via_queue() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + let mem = MockMem::new(0x10000); + let mut queue = setup_queue(128); + + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tversion(), 0); + let _resp = submit_request(&mut dev, &mem, &mut queue, &build_tattach(0), 1); + + let resp = submit_request(&mut dev, &mem, &mut queue, &build_tgetattr(0), 2); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RGETATTR); + } + + // -- Error response for bad fid -- + + #[test] + fn test_error_bad_fid() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + dev.process_request(&build_tversion()); + + // Try to walk with unattached fid. + let resp = dev.process_request(&build_twalk(99, 1, &["foo"])); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RLERROR); + let ecode = r.get_u32().unwrap(); + assert_eq!(ecode, filesystem::EBADF); + } + + // -- Clunk -- + + #[test] + fn test_clunk() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + dev.process_request(&build_tversion()); + dev.process_request(&build_tattach(0)); + + let resp = dev.process_request(&build_tclunk(0)); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RCLUNK); + + // Fid 0 should now be invalid. 
+ let resp = dev.process_request(&build_twalk(0, 1, &[])); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RLERROR); + } + + // -- Multiple requests in sequence -- + + #[test] + fn test_multiple_requests() { + let tmp = TempDir::new().unwrap(); + std::fs::File::create(tmp.path().join("f1.txt")).unwrap(); + std::fs::File::create(tmp.path().join("f2.txt")).unwrap(); + + let mut dev = create_test_device(&tmp); + dev.process_request(&build_tversion()); + dev.process_request(&build_tattach(0)); + + // Walk to two different files. + let resp1 = dev.process_request(&build_twalk(0, 1, &["f1.txt"])); + let resp2 = dev.process_request(&build_twalk(0, 2, &["f2.txt"])); + + let mut r1 = ByteReader::new(&resp1); + assert_eq!(P9Header::read_from(&mut r1).unwrap().msg_type, P9_RWALK); + + let mut r2 = ByteReader::new(&resp2); + assert_eq!(P9Header::read_from(&mut r2).unwrap().msg_type, P9_RWALK); + } + + // -- Short/malformed request -- + + #[test] + fn test_malformed_request() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + + // Too short for a header. 
+ let resp = dev.process_request(&[0, 0, 0]); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RLERROR); + } + + // -- Flush -- + + #[test] + fn test_flush() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + + let msg = build_response(P9_TFLUSH, 10, |w| { + w.put_u16(5); // oldtag + }); + let resp = dev.process_request(&msg); + let mut r = ByteReader::new(&resp); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RFLUSH); + assert_eq!(hdr.tag, 10); + } + + // -- Empty chain handled -- + + #[test] + fn test_empty_chain_skipped() { + let tmp = TempDir::new().unwrap(); + let mut dev = create_test_device(&tmp); + let mem = MockMem::new(0x10000); + let mut queue = setup_queue(128); + + // Descriptor with 0 length. + write_descriptor(&mem, 0, BUF_BASE, 0, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.queue_notify(0, &mut queue, &mem); + assert!(processed); + } +} diff --git a/src/vmm/src/windows/devices/virtio/p9/protocol.rs b/src/vmm/src/windows/devices/virtio/p9/protocol.rs new file mode 100644 index 000000000..c9a6b5932 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/p9/protocol.rs @@ -0,0 +1,1316 @@ +//! 9P2000.L wire protocol types and serialization. +//! +//! Implements the 9P2000.L message format used by Linux v9fs. +//! All multi-byte fields are little-endian. Messages have the format: +//! size[4] type[1] tag[2] params... 
+ +// -- 9P message type constants -- + +pub const P9_RLERROR: u8 = 7; +pub const P9_TLOPEN: u8 = 12; +pub const P9_RLOPEN: u8 = 13; +pub const P9_TLCREATE: u8 = 14; +pub const P9_RLCREATE: u8 = 15; +pub const P9_TGETATTR: u8 = 24; +pub const P9_RGETATTR: u8 = 25; +pub const P9_TSETATTR: u8 = 26; +pub const P9_RSETATTR: u8 = 27; +pub const P9_TREADDIR: u8 = 40; +pub const P9_RREADDIR: u8 = 41; +pub const P9_TFSYNC: u8 = 50; +pub const P9_RFSYNC: u8 = 51; +pub const P9_TMKDIR: u8 = 72; +pub const P9_RMKDIR: u8 = 73; +pub const P9_TRENAMEAT: u8 = 74; +pub const P9_RRENAMEAT: u8 = 75; +pub const P9_TUNLINKAT: u8 = 76; +pub const P9_RUNLINKAT: u8 = 77; +pub const P9_TVERSION: u8 = 100; +pub const P9_RVERSION: u8 = 101; +pub const P9_TATTACH: u8 = 104; +pub const P9_RATTACH: u8 = 105; +pub const P9_TFLUSH: u8 = 108; +pub const P9_RFLUSH: u8 = 109; +pub const P9_TWALK: u8 = 110; +pub const P9_RWALK: u8 = 111; +pub const P9_TREAD: u8 = 116; +pub const P9_RREAD: u8 = 117; +pub const P9_TWRITE: u8 = 118; +pub const P9_RWRITE: u8 = 119; +pub const P9_TCLUNK: u8 = 120; +pub const P9_RCLUNK: u8 = 121; + +/// No-fid sentinel. +pub const P9_NOFID: u32 = u32::MAX; + +/// No-tag sentinel (used in Tversion). +pub const P9_NOTAG: u16 = u16::MAX; + +/// 9P message header size (size[4] + type[1] + tag[2]). +pub const P9_HEADER_SIZE: usize = 7; + +/// QID size in bytes (type[1] + version[4] + path[8]). +pub const QID_SIZE: usize = 13; + +/// QID type: directory. +pub const QT_DIR: u8 = 0x80; +/// QID type: regular file. +pub const QT_FILE: u8 = 0x00; +/// QID type: symlink. +pub const QT_SYMLINK: u8 = 0x02; + +/// Default maximum message size. +pub const DEFAULT_MSIZE: u32 = 8192 + P9_HEADER_SIZE as u32; + +// -- QID -- + +/// 13-byte file identifier (type, version, path). 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Qid { + pub qtype: u8, + pub version: u32, + pub path: u64, +} + +impl Qid { + pub fn write_to(&self, w: &mut ByteWriter) { + w.put_u8(self.qtype); + w.put_u32(self.version); + w.put_u64(self.path); + } + + pub fn read_from(r: &mut ByteReader) -> Option { + let qtype = r.get_u8()?; + let version = r.get_u32()?; + let path = r.get_u64()?; + Some(Qid { + qtype, + version, + path, + }) + } +} + +// -- P9 message header -- + +/// Parsed 9P message header. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct P9Header { + pub size: u32, + pub msg_type: u8, + pub tag: u16, +} + +impl P9Header { + pub fn read_from(r: &mut ByteReader) -> Option { + let size = r.get_u32()?; + let msg_type = r.get_u8()?; + let tag = r.get_u16()?; + Some(P9Header { + size, + msg_type, + tag, + }) + } + + pub fn write_to(&self, w: &mut ByteWriter) { + w.put_u32(self.size); + w.put_u8(self.msg_type); + w.put_u16(self.tag); + } +} + +// -- Parsed T-message requests -- + +/// Parsed 9P T-message (client request). 
+#[derive(Debug)] +pub enum P9Request { + Tversion { + msize: u32, + version: String, + }, + Tattach { + fid: u32, + afid: u32, + uname: String, + aname: String, + }, + Twalk { + fid: u32, + newfid: u32, + names: Vec, + }, + Tlopen { + fid: u32, + flags: u32, + }, + Tlcreate { + fid: u32, + name: String, + flags: u32, + mode: u32, + gid: u32, + }, + Tread { + fid: u32, + offset: u64, + count: u32, + }, + Twrite { + fid: u32, + offset: u64, + count: u32, + data: Vec, + }, + Treaddir { + fid: u32, + offset: u64, + count: u32, + }, + Tgetattr { + fid: u32, + request_mask: u64, + }, + Tsetattr { + fid: u32, + valid: u32, + mode: u32, + uid: u32, + gid: u32, + size: u64, + atime_sec: u64, + atime_nsec: u64, + mtime_sec: u64, + mtime_nsec: u64, + }, + Tclunk { + fid: u32, + }, + Tflush { + oldtag: u16, + }, + Tmkdir { + dfid: u32, + name: String, + mode: u32, + gid: u32, + }, + Trenameat { + olddirfid: u32, + oldname: String, + newdirfid: u32, + newname: String, + }, + Tunlinkat { + dirfid: u32, + name: String, + flags: u32, + }, + Tfsync { + fid: u32, + }, +} + +/// Parse a T-message body (after header has been read). 
+pub fn parse_request(msg_type: u8, body: &[u8]) -> Option { + let mut r = ByteReader::new(body); + match msg_type { + P9_TVERSION => { + let msize = r.get_u32()?; + let version = r.get_string()?; + Some(P9Request::Tversion { msize, version }) + } + P9_TATTACH => { + let fid = r.get_u32()?; + let afid = r.get_u32()?; + let uname = r.get_string()?; + let aname = r.get_string()?; + Some(P9Request::Tattach { + fid, + afid, + uname, + aname, + }) + } + P9_TWALK => { + let fid = r.get_u32()?; + let newfid = r.get_u32()?; + let nwname = r.get_u16()?; + let mut names = Vec::with_capacity(nwname as usize); + for _ in 0..nwname { + names.push(r.get_string()?); + } + Some(P9Request::Twalk { fid, newfid, names }) + } + P9_TLOPEN => { + let fid = r.get_u32()?; + let flags = r.get_u32()?; + Some(P9Request::Tlopen { fid, flags }) + } + P9_TLCREATE => { + let fid = r.get_u32()?; + let name = r.get_string()?; + let flags = r.get_u32()?; + let mode = r.get_u32()?; + let gid = r.get_u32()?; + Some(P9Request::Tlcreate { + fid, + name, + flags, + mode, + gid, + }) + } + P9_TREAD => { + let fid = r.get_u32()?; + let offset = r.get_u64()?; + let count = r.get_u32()?; + Some(P9Request::Tread { fid, offset, count }) + } + P9_TWRITE => { + let fid = r.get_u32()?; + let offset = r.get_u64()?; + let count = r.get_u32()?; + let data = r.get_bytes(count as usize)?; + Some(P9Request::Twrite { + fid, + offset, + count, + data, + }) + } + P9_TREADDIR => { + let fid = r.get_u32()?; + let offset = r.get_u64()?; + let count = r.get_u32()?; + Some(P9Request::Treaddir { fid, offset, count }) + } + P9_TGETATTR => { + let fid = r.get_u32()?; + let request_mask = r.get_u64()?; + Some(P9Request::Tgetattr { fid, request_mask }) + } + P9_TSETATTR => { + let fid = r.get_u32()?; + let valid = r.get_u32()?; + let mode = r.get_u32()?; + let uid = r.get_u32()?; + let gid = r.get_u32()?; + let size = r.get_u64()?; + let atime_sec = r.get_u64()?; + let atime_nsec = r.get_u64()?; + let mtime_sec = r.get_u64()?; + 
let mtime_nsec = r.get_u64()?; + Some(P9Request::Tsetattr { + fid, + valid, + mode, + uid, + gid, + size, + atime_sec, + atime_nsec, + mtime_sec, + mtime_nsec, + }) + } + P9_TCLUNK => { + let fid = r.get_u32()?; + Some(P9Request::Tclunk { fid }) + } + P9_TFLUSH => { + let oldtag = r.get_u16()?; + Some(P9Request::Tflush { oldtag }) + } + P9_TMKDIR => { + let dfid = r.get_u32()?; + let name = r.get_string()?; + let mode = r.get_u32()?; + let gid = r.get_u32()?; + Some(P9Request::Tmkdir { + dfid, + name, + mode, + gid, + }) + } + P9_TRENAMEAT => { + let olddirfid = r.get_u32()?; + let oldname = r.get_string()?; + let newdirfid = r.get_u32()?; + let newname = r.get_string()?; + Some(P9Request::Trenameat { + olddirfid, + oldname, + newdirfid, + newname, + }) + } + P9_TUNLINKAT => { + let dirfid = r.get_u32()?; + let name = r.get_string()?; + let flags = r.get_u32()?; + Some(P9Request::Tunlinkat { + dirfid, + name, + flags, + }) + } + P9_TFSYNC => { + let fid = r.get_u32()?; + Some(P9Request::Tfsync { fid }) + } + _ => None, + } +} + +// -- P9Attr: Rgetattr response payload -- + +/// File attributes for Rgetattr. 
+#[derive(Debug, Clone)] +pub struct P9Attr { + pub valid: u64, + pub qid: Qid, + pub mode: u32, + pub uid: u32, + pub gid: u32, + pub nlink: u64, + pub rdev: u64, + pub size: u64, + pub blksize: u64, + pub blocks: u64, + pub atime_sec: u64, + pub atime_nsec: u64, + pub mtime_sec: u64, + pub mtime_nsec: u64, + pub ctime_sec: u64, + pub ctime_nsec: u64, + pub btime_sec: u64, + pub btime_nsec: u64, + pub gen: u64, + pub data_version: u64, +} + +impl P9Attr { + pub fn write_to(&self, w: &mut ByteWriter) { + w.put_u64(self.valid); + self.qid.write_to(w); + w.put_u32(self.mode); + w.put_u32(self.uid); + w.put_u32(self.gid); + w.put_u64(self.nlink); + w.put_u64(self.rdev); + w.put_u64(self.size); + w.put_u64(self.blksize); + w.put_u64(self.blocks); + w.put_u64(self.atime_sec); + w.put_u64(self.atime_nsec); + w.put_u64(self.mtime_sec); + w.put_u64(self.mtime_nsec); + w.put_u64(self.ctime_sec); + w.put_u64(self.ctime_nsec); + w.put_u64(self.btime_sec); + w.put_u64(self.btime_nsec); + w.put_u64(self.gen); + w.put_u64(self.data_version); + } +} + +// -- Response builders -- + +/// Build an Rlerror response body (after header). +pub fn write_rlerror(w: &mut ByteWriter, ecode: u32) { + w.put_u32(ecode); +} + +/// Build an Rversion response body. +pub fn write_rversion(w: &mut ByteWriter, msize: u32, version: &str) { + w.put_u32(msize); + w.put_string(version); +} + +/// Build an Rattach response body. +pub fn write_rattach(w: &mut ByteWriter, qid: &Qid) { + qid.write_to(w); +} + +/// Build an Rwalk response body. +pub fn write_rwalk(w: &mut ByteWriter, qids: &[Qid]) { + w.put_u16(qids.len() as u16); + for qid in qids { + qid.write_to(w); + } +} + +/// Build an Rlopen response body. +pub fn write_rlopen(w: &mut ByteWriter, qid: &Qid, iounit: u32) { + qid.write_to(w); + w.put_u32(iounit); +} + +/// Build an Rlcreate response body. +pub fn write_rlcreate(w: &mut ByteWriter, qid: &Qid, iounit: u32) { + qid.write_to(w); + w.put_u32(iounit); +} + +/// Build an Rread response body. 
+pub fn write_rread(w: &mut ByteWriter, data: &[u8]) { + w.put_u32(data.len() as u32); + w.put_raw(data); +} + +/// Build an Rwrite response body. +pub fn write_rwrite(w: &mut ByteWriter, count: u32) { + w.put_u32(count); +} + +/// Build an Rreaddir response body. +pub fn write_rreaddir(w: &mut ByteWriter, data: &[u8]) { + w.put_u32(data.len() as u32); + w.put_raw(data); +} + +/// Build an Rgetattr response body. +pub fn write_rgetattr(w: &mut ByteWriter, attr: &P9Attr) { + attr.write_to(w); +} + +/// Build an Rclunk response body (empty). +pub fn write_rclunk(_w: &mut ByteWriter) { + // No body. +} + +/// Build an Rflush response body (empty). +pub fn write_rflush(_w: &mut ByteWriter) { + // No body. +} + +/// Build an Rsetattr response body (empty). +pub fn write_rsetattr(_w: &mut ByteWriter) { + // No body. +} + +/// Build an Rmkdir response body. +pub fn write_rmkdir(w: &mut ByteWriter, qid: &Qid) { + qid.write_to(w); +} + +/// Build an Rrenameat response body (empty). +pub fn write_rrenameat(_w: &mut ByteWriter) { + // No body. +} + +/// Build an Runlinkat response body (empty). +pub fn write_runlinkat(_w: &mut ByteWriter) { + // No body. +} + +/// Build an Rfsync response body (empty). +pub fn write_rfsync(_w: &mut ByteWriter) { + // No body. +} + +// -- ByteReader: sequential reader over a byte slice -- + +/// Cursor for reading fields from a byte buffer. 
+pub struct ByteReader<'a> { + data: &'a [u8], + pos: usize, +} + +impl<'a> ByteReader<'a> { + pub fn new(data: &'a [u8]) -> Self { + ByteReader { data, pos: 0 } + } + + pub fn remaining(&self) -> usize { + self.data.len().saturating_sub(self.pos) + } + + pub fn get_u8(&mut self) -> Option { + if self.pos + 1 > self.data.len() { + return None; + } + let val = self.data[self.pos]; + self.pos += 1; + Some(val) + } + + pub fn get_u16(&mut self) -> Option { + if self.pos + 2 > self.data.len() { + return None; + } + let val = u16::from_le_bytes([self.data[self.pos], self.data[self.pos + 1]]); + self.pos += 2; + Some(val) + } + + pub fn get_u32(&mut self) -> Option { + if self.pos + 4 > self.data.len() { + return None; + } + let val = u32::from_le_bytes([ + self.data[self.pos], + self.data[self.pos + 1], + self.data[self.pos + 2], + self.data[self.pos + 3], + ]); + self.pos += 4; + Some(val) + } + + pub fn get_u64(&mut self) -> Option { + if self.pos + 8 > self.data.len() { + return None; + } + let val = u64::from_le_bytes([ + self.data[self.pos], + self.data[self.pos + 1], + self.data[self.pos + 2], + self.data[self.pos + 3], + self.data[self.pos + 4], + self.data[self.pos + 5], + self.data[self.pos + 6], + self.data[self.pos + 7], + ]); + self.pos += 8; + Some(val) + } + + /// Read a 9P string: length[2] + data[length]. + pub fn get_string(&mut self) -> Option { + let len = self.get_u16()? as usize; + let bytes = self.get_bytes(len)?; + String::from_utf8(bytes).ok() + } + + pub fn get_bytes(&mut self, count: usize) -> Option> { + if self.pos + count > self.data.len() { + return None; + } + let val = self.data[self.pos..self.pos + count].to_vec(); + self.pos += count; + Some(val) + } +} + +// -- ByteWriter: sequential writer into a byte buffer -- + +/// Cursor for writing fields into a growable byte buffer. 
///
/// All multi-byte integers are written little-endian.
pub struct ByteWriter {
    data: Vec<u8>,
}

impl Default for ByteWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl ByteWriter {
    /// Create an empty writer.
    pub fn new() -> Self {
        ByteWriter { data: Vec::new() }
    }

    /// Create an empty writer with room for `cap` bytes before reallocating.
    pub fn with_capacity(cap: usize) -> Self {
        ByteWriter {
            data: Vec::with_capacity(cap),
        }
    }

    /// Number of bytes written so far.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` if nothing has been written yet.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Consume the writer and return the accumulated bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        self.data
    }

    /// Borrow the bytes written so far.
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    pub fn put_u8(&mut self, val: u8) {
        self.data.push(val);
    }

    pub fn put_u16(&mut self, val: u16) {
        self.data.extend_from_slice(&val.to_le_bytes());
    }

    pub fn put_u32(&mut self, val: u32) {
        self.data.extend_from_slice(&val.to_le_bytes());
    }

    pub fn put_u64(&mut self, val: u64) {
        self.data.extend_from_slice(&val.to_le_bytes());
    }

    /// Write a 9P string: length[2] + data[length].
    ///
    /// The length prefix is 16 bits, so at most `u16::MAX` bytes can be
    /// represented. A longer string is cut at the last character boundary
    /// that fits, which keeps the prefix and the payload in agreement.
    /// (Casting the length with `as u16` while writing the whole string
    /// would desynchronize every field that follows.)
    pub fn put_string(&mut self, s: &str) {
        let mut end = s.len().min(usize::from(u16::MAX));
        // Index 0 is always a boundary, so this loop terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        self.put_u16(end as u16); // lossless: end <= u16::MAX
        self.data.extend_from_slice(&s.as_bytes()[..end]);
    }

    /// Append `data` verbatim, with no length prefix.
    pub fn put_raw(&mut self, data: &[u8]) {
        self.data.extend_from_slice(data);
    }

    /// Patch a u32 at the given byte offset (used for message size fixup).
    ///
    /// # Panics
    ///
    /// Panics if `offset + 4` exceeds the number of bytes written so far.
    pub fn patch_u32(&mut self, offset: usize, val: u32) {
        let bytes = val.to_le_bytes();
        self.data[offset..offset + 4].copy_from_slice(&bytes);
    }
}

/// Build a complete 9P response message (header + body).
///
/// `msg_type` is the R-message type constant.
/// `tag` is the request tag to echo back.
/// `body_fn` writes the body fields into the ByteWriter.
///
/// The leading size field is written as a placeholder and patched once the
/// body is complete, so `body_fn` never needs to know the final length.
pub fn build_response(msg_type: u8, tag: u16, body_fn: impl FnOnce(&mut ByteWriter)) -> Vec<u8> {
    let mut w = ByteWriter::with_capacity(128);
    // Reserve space for the size field.
    w.put_u32(0);
    w.put_u8(msg_type);
    w.put_u16(tag);
    body_fn(&mut w);
    // Patch the size field with the total message length.
    let total = w.len() as u32;
    w.patch_u32(0, total);
    w.into_bytes()
}

#[cfg(test)]
mod tests {
    use super::*;

    // -- ByteReader tests --

    #[test]
    fn test_reader_u8() {
        let data = [0x42];
        let mut r = ByteReader::new(&data);
        assert_eq!(r.get_u8(), Some(0x42));
        assert_eq!(r.get_u8(), None);
    }

    #[test]
    fn test_reader_u16() {
        let data = 0x1234u16.to_le_bytes();
        let mut r = ByteReader::new(&data);
        assert_eq!(r.get_u16(), Some(0x1234));
        assert_eq!(r.remaining(), 0);
    }

    #[test]
    fn test_reader_u32() {
        let data = 0xDEADBEEFu32.to_le_bytes();
        let mut r = ByteReader::new(&data);
        assert_eq!(r.get_u32(), Some(0xDEADBEEF));
    }

    #[test]
    fn test_reader_u64() {
        let data = 0x0102030405060708u64.to_le_bytes();
        let mut r = ByteReader::new(&data);
        assert_eq!(r.get_u64(), Some(0x0102030405060708));
    }

    #[test]
    fn test_reader_string() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&5u16.to_le_bytes());
        buf.extend_from_slice(b"hello");
        let mut r = ByteReader::new(&buf);
        assert_eq!(r.get_string(), Some("hello".to_string()));
    }

    #[test]
    fn test_reader_empty_string() {
        let buf = 0u16.to_le_bytes();
        let mut r = ByteReader::new(&buf);
        assert_eq!(r.get_string(), Some(String::new()));
    }

    #[test]
    fn test_reader_truncated_returns_none() {
        let data = [0x01]; // Only 1 byte, but asking for u32.
        let mut r = ByteReader::new(&data);
        assert_eq!(r.get_u32(), None);
    }

    #[test]
    fn test_reader_bytes() {
        let data = [1, 2, 3, 4, 5];
        let mut r = ByteReader::new(&data);
        assert_eq!(r.get_bytes(3), Some(vec![1, 2, 3]));
        assert_eq!(r.get_bytes(3), None); // Only 2 remaining.
+ assert_eq!(r.get_bytes(2), Some(vec![4, 5])); + } + + // -- ByteWriter tests -- + + #[test] + fn test_writer_roundtrip_u32() { + let mut w = ByteWriter::new(); + w.put_u32(0xCAFEBABE); + let mut r = ByteReader::new(w.as_bytes()); + assert_eq!(r.get_u32(), Some(0xCAFEBABE)); + } + + #[test] + fn test_writer_string() { + let mut w = ByteWriter::new(); + w.put_string("test"); + let mut r = ByteReader::new(w.as_bytes()); + assert_eq!(r.get_string(), Some("test".to_string())); + } + + #[test] + fn test_writer_patch_u32() { + let mut w = ByteWriter::new(); + w.put_u32(0); // Placeholder. + w.put_u8(0xFF); + w.patch_u32(0, 42); + assert_eq!(w.as_bytes()[0..4], 42u32.to_le_bytes()); + assert_eq!(w.as_bytes()[4], 0xFF); + } + + #[test] + fn test_writer_len() { + let mut w = ByteWriter::new(); + assert_eq!(w.len(), 0); + w.put_u32(0); + assert_eq!(w.len(), 4); + w.put_string("hi"); + assert_eq!(w.len(), 4 + 2 + 2); // u32 + u16_len + "hi" + } + + // -- Header tests -- + + #[test] + fn test_header_roundtrip() { + let hdr = P9Header { + size: 23, + msg_type: P9_TVERSION, + tag: P9_NOTAG, + }; + let mut w = ByteWriter::new(); + hdr.write_to(&mut w); + assert_eq!(w.len(), P9_HEADER_SIZE); + + let mut r = ByteReader::new(w.as_bytes()); + let parsed = P9Header::read_from(&mut r).unwrap(); + assert_eq!(parsed, hdr); + } + + // -- QID tests -- + + #[test] + fn test_qid_roundtrip() { + let qid = Qid { + qtype: QT_DIR, + version: 12345, + path: 0xDEAD_BEEF_CAFE, + }; + let mut w = ByteWriter::new(); + qid.write_to(&mut w); + assert_eq!(w.len(), QID_SIZE); + + let mut r = ByteReader::new(w.as_bytes()); + let parsed = Qid::read_from(&mut r).unwrap(); + assert_eq!(parsed, qid); + } + + #[test] + fn test_qid_file() { + let qid = Qid { + qtype: QT_FILE, + version: 0, + path: 1, + }; + let mut w = ByteWriter::new(); + qid.write_to(&mut w); + let mut r = ByteReader::new(w.as_bytes()); + let parsed = Qid::read_from(&mut r).unwrap(); + assert_eq!(parsed.qtype, QT_FILE); + } + + // -- Request 
parsing tests -- + + #[test] + fn test_parse_tversion() { + let mut w = ByteWriter::new(); + w.put_u32(8192); // msize + w.put_string("9P2000.L"); + let req = parse_request(P9_TVERSION, w.as_bytes()).unwrap(); + match req { + P9Request::Tversion { msize, version } => { + assert_eq!(msize, 8192); + assert_eq!(version, "9P2000.L"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tattach() { + let mut w = ByteWriter::new(); + w.put_u32(0); // fid + w.put_u32(P9_NOFID); // afid + w.put_string("root"); + w.put_string("/share"); + let req = parse_request(P9_TATTACH, w.as_bytes()).unwrap(); + match req { + P9Request::Tattach { + fid, + afid, + uname, + aname, + } => { + assert_eq!(fid, 0); + assert_eq!(afid, P9_NOFID); + assert_eq!(uname, "root"); + assert_eq!(aname, "/share"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_twalk_empty() { + let mut w = ByteWriter::new(); + w.put_u32(0); // fid + w.put_u32(1); // newfid + w.put_u16(0); // nwname = 0 + let req = parse_request(P9_TWALK, w.as_bytes()).unwrap(); + match req { + P9Request::Twalk { fid, newfid, names } => { + assert_eq!(fid, 0); + assert_eq!(newfid, 1); + assert!(names.is_empty()); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_twalk_multi() { + let mut w = ByteWriter::new(); + w.put_u32(0); + w.put_u32(1); + w.put_u16(3); + w.put_string("usr"); + w.put_string("local"); + w.put_string("bin"); + let req = parse_request(P9_TWALK, w.as_bytes()).unwrap(); + match req { + P9Request::Twalk { names, .. 
} => { + assert_eq!(names, vec!["usr", "local", "bin"]); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tlopen() { + let mut w = ByteWriter::new(); + w.put_u32(5); // fid + w.put_u32(0); // O_RDONLY + let req = parse_request(P9_TLOPEN, w.as_bytes()).unwrap(); + match req { + P9Request::Tlopen { fid, flags } => { + assert_eq!(fid, 5); + assert_eq!(flags, 0); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tread() { + let mut w = ByteWriter::new(); + w.put_u32(3); // fid + w.put_u64(100); // offset + w.put_u32(4096); // count + let req = parse_request(P9_TREAD, w.as_bytes()).unwrap(); + match req { + P9Request::Tread { fid, offset, count } => { + assert_eq!(fid, 3); + assert_eq!(offset, 100); + assert_eq!(count, 4096); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_twrite() { + let mut w = ByteWriter::new(); + w.put_u32(3); // fid + w.put_u64(0); // offset + w.put_u32(5); // count + w.put_raw(b"hello"); // data + let req = parse_request(P9_TWRITE, w.as_bytes()).unwrap(); + match req { + P9Request::Twrite { + fid, + offset, + count, + data, + } => { + assert_eq!(fid, 3); + assert_eq!(offset, 0); + assert_eq!(count, 5); + assert_eq!(data, b"hello"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tclunk() { + let mut w = ByteWriter::new(); + w.put_u32(7); + let req = parse_request(P9_TCLUNK, w.as_bytes()).unwrap(); + match req { + P9Request::Tclunk { fid } => assert_eq!(fid, 7), + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tflush() { + let mut w = ByteWriter::new(); + w.put_u16(42); + let req = parse_request(P9_TFLUSH, w.as_bytes()).unwrap(); + match req { + P9Request::Tflush { oldtag } => assert_eq!(oldtag, 42), + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tgetattr() { + let mut w = ByteWriter::new(); + w.put_u32(1); + w.put_u64(0x3FFF); // request_mask: all valid bits + let req = parse_request(P9_TGETATTR, 
w.as_bytes()).unwrap(); + match req { + P9Request::Tgetattr { fid, request_mask } => { + assert_eq!(fid, 1); + assert_eq!(request_mask, 0x3FFF); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_treaddir() { + let mut w = ByteWriter::new(); + w.put_u32(2); // fid + w.put_u64(0); // offset + w.put_u32(8192); // count + let req = parse_request(P9_TREADDIR, w.as_bytes()).unwrap(); + match req { + P9Request::Treaddir { fid, offset, count } => { + assert_eq!(fid, 2); + assert_eq!(offset, 0); + assert_eq!(count, 8192); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tmkdir() { + let mut w = ByteWriter::new(); + w.put_u32(1); // dfid + w.put_string("newdir"); + w.put_u32(0o755); // mode + w.put_u32(0); // gid + let req = parse_request(P9_TMKDIR, w.as_bytes()).unwrap(); + match req { + P9Request::Tmkdir { + dfid, + name, + mode, + gid, + } => { + assert_eq!(dfid, 1); + assert_eq!(name, "newdir"); + assert_eq!(mode, 0o755); + assert_eq!(gid, 0); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tunlinkat() { + let mut w = ByteWriter::new(); + w.put_u32(1); // dirfid + w.put_string("oldfile"); + w.put_u32(0); // flags + let req = parse_request(P9_TUNLINKAT, w.as_bytes()).unwrap(); + match req { + P9Request::Tunlinkat { + dirfid, + name, + flags, + } => { + assert_eq!(dirfid, 1); + assert_eq!(name, "oldfile"); + assert_eq!(flags, 0); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_trenameat() { + let mut w = ByteWriter::new(); + w.put_u32(1); // olddirfid + w.put_string("old.txt"); + w.put_u32(2); // newdirfid + w.put_string("new.txt"); + let req = parse_request(P9_TRENAMEAT, w.as_bytes()).unwrap(); + match req { + P9Request::Trenameat { + olddirfid, + oldname, + newdirfid, + newname, + } => { + assert_eq!(olddirfid, 1); + assert_eq!(oldname, "old.txt"); + assert_eq!(newdirfid, 2); + assert_eq!(newname, "new.txt"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn 
test_parse_tfsync() { + let mut w = ByteWriter::new(); + w.put_u32(5); + let req = parse_request(P9_TFSYNC, w.as_bytes()).unwrap(); + match req { + P9Request::Tfsync { fid } => assert_eq!(fid, 5), + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tlcreate() { + let mut w = ByteWriter::new(); + w.put_u32(1); // fid + w.put_string("newfile.txt"); + w.put_u32(0x42); // flags (O_CREAT|O_RDWR) + w.put_u32(0o644); // mode + w.put_u32(0); // gid + let req = parse_request(P9_TLCREATE, w.as_bytes()).unwrap(); + match req { + P9Request::Tlcreate { + fid, + name, + flags, + mode, + gid, + } => { + assert_eq!(fid, 1); + assert_eq!(name, "newfile.txt"); + assert_eq!(flags, 0x42); + assert_eq!(mode, 0o644); + assert_eq!(gid, 0); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_tsetattr() { + let mut w = ByteWriter::new(); + w.put_u32(3); // fid + w.put_u32(0x01); // valid (mode) + w.put_u32(0o755); // mode + w.put_u32(0); // uid + w.put_u32(0); // gid + w.put_u64(0); // size + w.put_u64(0); // atime_sec + w.put_u64(0); // atime_nsec + w.put_u64(0); // mtime_sec + w.put_u64(0); // mtime_nsec + let req = parse_request(P9_TSETATTR, w.as_bytes()).unwrap(); + match req { + P9Request::Tsetattr { + fid, valid, mode, .. 
+ } => { + assert_eq!(fid, 3); + assert_eq!(valid, 0x01); + assert_eq!(mode, 0o755); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_parse_unknown_type_returns_none() { + assert!(parse_request(0xFF, &[]).is_none()); + } + + // -- Response builder tests -- + + #[test] + fn test_build_rversion() { + let msg = build_response(P9_RVERSION, P9_NOTAG, |w| { + write_rversion(w, 8192, "9P2000.L"); + }); + let mut r = ByteReader::new(&msg); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RVERSION); + assert_eq!(hdr.tag, P9_NOTAG); + assert_eq!(hdr.size as usize, msg.len()); + + let msize = r.get_u32().unwrap(); + let version = r.get_string().unwrap(); + assert_eq!(msize, 8192); + assert_eq!(version, "9P2000.L"); + } + + #[test] + fn test_build_rlerror() { + let msg = build_response(P9_RLERROR, 1, |w| { + write_rlerror(w, 2); // ENOENT + }); + let mut r = ByteReader::new(&msg); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RLERROR); + assert_eq!(hdr.tag, 1); + let ecode = r.get_u32().unwrap(); + assert_eq!(ecode, 2); + } + + #[test] + fn test_build_rwalk() { + let qids = vec![ + Qid { + qtype: QT_DIR, + version: 1, + path: 100, + }, + Qid { + qtype: QT_FILE, + version: 2, + path: 200, + }, + ]; + let msg = build_response(P9_RWALK, 5, |w| { + write_rwalk(w, &qids); + }); + let mut r = ByteReader::new(&msg); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RWALK); + let nwqid = r.get_u16().unwrap(); + assert_eq!(nwqid, 2); + let q1 = Qid::read_from(&mut r).unwrap(); + assert_eq!(q1.path, 100); + let q2 = Qid::read_from(&mut r).unwrap(); + assert_eq!(q2.path, 200); + } + + #[test] + fn test_build_rread() { + let msg = build_response(P9_RREAD, 3, |w| { + write_rread(w, b"file data"); + }); + let mut r = ByteReader::new(&msg); + let _hdr = P9Header::read_from(&mut r).unwrap(); + let count = r.get_u32().unwrap(); + assert_eq!(count, 9); + let data = r.get_bytes(count 
as usize).unwrap(); + assert_eq!(data, b"file data"); + } + + #[test] + fn test_build_rwrite() { + let msg = build_response(P9_RWRITE, 3, |w| { + write_rwrite(w, 42); + }); + let mut r = ByteReader::new(&msg); + let _hdr = P9Header::read_from(&mut r).unwrap(); + let count = r.get_u32().unwrap(); + assert_eq!(count, 42); + } + + #[test] + fn test_build_rlopen() { + let qid = Qid { + qtype: QT_FILE, + version: 1, + path: 42, + }; + let msg = build_response(P9_RLOPEN, 2, |w| { + write_rlopen(w, &qid, 4096); + }); + let mut r = ByteReader::new(&msg); + let _hdr = P9Header::read_from(&mut r).unwrap(); + let q = Qid::read_from(&mut r).unwrap(); + assert_eq!(q, qid); + let iounit = r.get_u32().unwrap(); + assert_eq!(iounit, 4096); + } + + #[test] + fn test_build_rattach() { + let qid = Qid { + qtype: QT_DIR, + version: 0, + path: 1, + }; + let msg = build_response(P9_RATTACH, 0, |w| { + write_rattach(w, &qid); + }); + let mut r = ByteReader::new(&msg); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RATTACH); + let q = Qid::read_from(&mut r).unwrap(); + assert_eq!(q, qid); + } + + #[test] + fn test_build_response_size_correct() { + // Rclunk is header-only (7 bytes total). 
+ let msg = build_response(P9_RCLUNK, 10, |w| { + write_rclunk(w); + }); + assert_eq!(msg.len(), P9_HEADER_SIZE); + let mut r = ByteReader::new(&msg); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.size as usize, P9_HEADER_SIZE); + } + + #[test] + fn test_build_rmkdir() { + let qid = Qid { + qtype: QT_DIR, + version: 3, + path: 99, + }; + let msg = build_response(P9_RMKDIR, 7, |w| { + write_rmkdir(w, &qid); + }); + let mut r = ByteReader::new(&msg); + let hdr = P9Header::read_from(&mut r).unwrap(); + assert_eq!(hdr.msg_type, P9_RMKDIR); + let q = Qid::read_from(&mut r).unwrap(); + assert_eq!(q, qid); + } +} diff --git a/src/vmm/src/windows/devices/virtio/queue.rs b/src/vmm/src/windows/devices/virtio/queue.rs new file mode 100644 index 000000000..eecd5f642 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/queue.rs @@ -0,0 +1,699 @@ +//! Split virtqueue implementation (virtio spec v1.2 Section 2.7). +//! +//! A split virtqueue consists of three regions in guest memory: +//! - Descriptor table: array of buffer descriptors +//! - Available ring: guest-to-device buffer indices +//! - Used ring: device-to-guest completion notifications + +use super::super::super::error::{Result, WkrunError}; + +/// Abstraction over guest physical memory for cross-platform testing. +pub trait GuestMemoryAccessor { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()>; + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()>; +} + +/// Extension methods for reading typed values from guest memory. 
+trait GuestMemoryExt: GuestMemoryAccessor { + fn read_u16(&self, addr: u64) -> Result { + let mut buf = [0u8; 2]; + self.read_at(addr, &mut buf)?; + Ok(u16::from_le_bytes(buf)) + } + + fn read_u32(&self, addr: u64) -> Result { + let mut buf = [0u8; 4]; + self.read_at(addr, &mut buf)?; + Ok(u32::from_le_bytes(buf)) + } + + fn read_u64(&self, addr: u64) -> Result { + let mut buf = [0u8; 8]; + self.read_at(addr, &mut buf)?; + Ok(u64::from_le_bytes(buf)) + } + + fn write_u16(&self, addr: u64, val: u16) -> Result<()> { + self.write_at(addr, &val.to_le_bytes()) + } + + fn write_u32(&self, addr: u64, val: u32) -> Result<()> { + self.write_at(addr, &val.to_le_bytes()) + } +} + +impl GuestMemoryExt for T {} + +// Descriptor table entry layout (virtio spec 2.7.5). +const DESC_ADDR_OFFSET: u64 = 0; +const DESC_LEN_OFFSET: u64 = 8; +const DESC_FLAGS_OFFSET: u64 = 12; +const DESC_NEXT_OFFSET: u64 = 14; +const DESC_SIZE: u64 = 16; + +/// Descriptor flag: buffer is device-writable (for reads from device). +const VIRTQ_DESC_F_WRITE: u16 = 2; +/// Descriptor flag: next field is valid (chained descriptor). +const VIRTQ_DESC_F_NEXT: u16 = 1; + +/// A single descriptor from the descriptor table. +#[derive(Debug, Clone, Copy)] +pub struct Descriptor { + /// Guest physical address of the buffer. + pub addr: u64, + /// Length of the buffer in bytes. + pub len: u32, + /// Descriptor flags. + pub flags: u16, + /// Next descriptor index (valid only if VIRTQ_DESC_F_NEXT is set). + pub next: u16, +} + +impl Descriptor { + /// Whether the buffer is device-writable (guest reads from it). + pub fn is_write(&self) -> bool { + self.flags & VIRTQ_DESC_F_WRITE != 0 + } + + /// Whether there is a next descriptor in the chain. + pub fn has_next(&self) -> bool { + self.flags & VIRTQ_DESC_F_NEXT != 0 + } +} + +/// A split virtqueue. +pub struct Virtqueue { + /// Maximum queue size (device sets this). + max_size: u16, + /// Negotiated queue size (driver sets this, must be <= max_size and power of 2). 
+ size: u16, + /// Whether the queue is ready for use. + ready: bool, + /// Guest physical address of the descriptor table. + desc_table_addr: u64, + /// Guest physical address of the available ring. + avail_ring_addr: u64, + /// Guest physical address of the used ring. + used_ring_addr: u64, + /// Last available index consumed by the device. + last_avail_idx: u16, +} + +impl Virtqueue { + /// Create a new virtqueue with the given maximum size. + pub fn new(max_size: u16) -> Self { + Virtqueue { + max_size, + size: 0, + ready: false, + desc_table_addr: 0, + avail_ring_addr: 0, + used_ring_addr: 0, + last_avail_idx: 0, + } + } + + /// Get the maximum queue size. + pub fn max_size(&self) -> u16 { + self.max_size + } + + /// Get the current queue size. + pub fn size(&self) -> u16 { + self.size + } + + /// Set the queue size (called by driver during setup). + pub fn set_size(&mut self, size: u16) { + self.size = size; + } + + /// Whether the queue is ready for I/O. + pub fn is_ready(&self) -> bool { + self.ready + } + + /// Mark the queue as ready. + pub fn set_ready(&mut self, ready: bool) { + self.ready = ready; + } + + /// Set the descriptor table address. + pub fn set_desc_table(&mut self, addr: u64) { + self.desc_table_addr = addr; + } + + /// Set the available ring address. + pub fn set_avail_ring(&mut self, addr: u64) { + self.avail_ring_addr = addr; + } + + /// Set the used ring address. + pub fn set_used_ring(&mut self, addr: u64) { + self.used_ring_addr = addr; + } + + /// Read a descriptor from the descriptor table by index. 
+ fn read_descriptor( + &self, + index: u16, + mem: &(impl GuestMemoryAccessor + ?Sized), + ) -> Result { + if index >= self.size { + return Err(WkrunError::Device(format!( + "descriptor index {} out of bounds (queue size {})", + index, self.size + ))); + } + let addr = self.desc_table_addr + (index as u64) * DESC_SIZE; + Ok(Descriptor { + addr: mem.read_u64(addr + DESC_ADDR_OFFSET)?, + len: mem.read_u32(addr + DESC_LEN_OFFSET)?, + flags: mem.read_u16(addr + DESC_FLAGS_OFFSET)?, + next: mem.read_u16(addr + DESC_NEXT_OFFSET)?, + }) + } + + /// Pop the next available descriptor chain head index, if any. + /// + /// Returns `None` if no new buffers are available. + pub fn pop_avail(&mut self, mem: &(impl GuestMemoryAccessor + ?Sized)) -> Result> { + if !self.ready || self.size == 0 { + return Ok(None); + } + + // Avail ring layout: flags(u16) + idx(u16) + ring[size](u16 each) + let avail_idx = mem.read_u16(self.avail_ring_addr + 2)?; + + if self.last_avail_idx == avail_idx { + return Ok(None); // No new buffers. + } + + let ring_offset = 4 + (self.last_avail_idx % self.size) as u64 * 2; + let head = mem.read_u16(self.avail_ring_addr + ring_offset)?; + + self.last_avail_idx = self.last_avail_idx.wrapping_add(1); + Ok(Some(head)) + } + + /// Read an entire descriptor chain starting from the given head index. + /// + /// Returns the chain of descriptors. Detects cycles by limiting + /// the chain length to the queue size. + pub fn read_desc_chain( + &self, + head: u16, + mem: &(impl GuestMemoryAccessor + ?Sized), + ) -> Result> { + let mut chain = Vec::new(); + let mut index = head; + let max_chain = self.size as usize; + + loop { + if chain.len() >= max_chain { + return Err(WkrunError::Device(format!( + "descriptor chain too long (> {}), possible cycle", + max_chain + ))); + } + + let desc = self.read_descriptor(index, mem)?; + chain.push(desc); + + if !desc.has_next() { + break; + } + index = desc.next; + } + + Ok(chain) + } + + /// Add a used buffer to the used ring. 
+ /// + /// `head` is the descriptor chain head index (from `pop_avail`). + /// `len` is the total bytes written to the descriptor chain. + pub fn add_used( + &mut self, + head: u16, + len: u32, + mem: &(impl GuestMemoryAccessor + ?Sized), + ) -> Result<()> { + if !self.ready || self.size == 0 { + return Err(WkrunError::Device("queue not ready".into())); + } + + // Used ring layout: flags(u16) + idx(u16) + ring[size](id:u32 + len:u32) + let used_idx = mem.read_u16(self.used_ring_addr + 2)?; + let ring_entry_offset = 4 + (used_idx % self.size) as u64 * 8; + let entry_addr = self.used_ring_addr + ring_entry_offset; + + // Write used ring entry: {id: u32, len: u32}. + mem.write_u32(entry_addr, head as u32)?; + mem.write_u32(entry_addr + 4, len)?; + + // Increment used index. + mem.write_u16(self.used_ring_addr + 2, used_idx.wrapping_add(1))?; + + Ok(()) + } + + /// Reset the queue to its initial state. + pub fn reset(&mut self) { + self.size = 0; + self.ready = false; + self.desc_table_addr = 0; + self.avail_ring_addr = 0; + self.used_ring_addr = 0; + self.last_avail_idx = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::cell::RefCell; + + /// Mock guest memory backed by a Vec. 
+ struct MockGuestMemory { + data: RefCell>, + } + + impl MockGuestMemory { + fn new(size: usize) -> Self { + MockGuestMemory { + data: RefCell::new(vec![0u8; size]), + } + } + + fn write_u16_at(&self, addr: u64, val: u16) { + let a = addr as usize; + let bytes = val.to_le_bytes(); + let mut data = self.data.borrow_mut(); + data[a..a + 2].copy_from_slice(&bytes); + } + + fn write_u32_at(&self, addr: u64, val: u32) { + let a = addr as usize; + let bytes = val.to_le_bytes(); + let mut data = self.data.borrow_mut(); + data[a..a + 4].copy_from_slice(&bytes); + } + + fn write_u64_at(&self, addr: u64, val: u64) { + let a = addr as usize; + let bytes = val.to_le_bytes(); + let mut data = self.data.borrow_mut(); + data[a..a + 8].copy_from_slice(&bytes); + } + + fn read_u16_at(&self, addr: u64) -> u16 { + let a = addr as usize; + let data = self.data.borrow(); + u16::from_le_bytes([data[a], data[a + 1]]) + } + + fn read_u32_at(&self, addr: u64) -> u32 { + let a = addr as usize; + let data = self.data.borrow(); + u32::from_le_bytes([data[a], data[a + 1], data[a + 2], data[a + 3]]) + } + } + + impl GuestMemoryAccessor for MockGuestMemory { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + let a = addr as usize; + let data = self.data.borrow(); + if a + buf.len() > data.len() { + return Err(WkrunError::Memory(format!( + "read out of bounds: 0x{:X} + {}", + addr, + buf.len() + ))); + } + buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + let a = addr as usize; + let mut mem = self.data.borrow_mut(); + if a + data.len() > mem.len() { + return Err(WkrunError::Memory(format!( + "write out of bounds: 0x{:X} + {}", + addr, + data.len() + ))); + } + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + // Memory layout for tests: + // DESC_TABLE at 0x0000 (256 entries * 16 bytes = 4096 bytes) + // AVAIL_RING at 0x1000 (flags:2 + idx:2 + ring[256]:512 + used_event:2 = 518) + // 
USED_RING at 0x2000 (flags:2 + idx:2 + ring[256]:(4+4)*256=2048 + avail_event:2 = 2054) + const DESC_TABLE: u64 = 0x0000; + const AVAIL_RING: u64 = 0x1000; + const USED_RING: u64 = 0x2000; + + fn setup_queue(max_size: u16) -> Virtqueue { + let mut q = Virtqueue::new(max_size); + q.set_size(max_size); + q.set_desc_table(DESC_TABLE); + q.set_avail_ring(AVAIL_RING); + q.set_used_ring(USED_RING); + q.set_ready(true); + q + } + + /// Write a descriptor into mock memory. + fn write_descriptor( + mem: &MockGuestMemory, + index: u16, + addr: u64, + len: u32, + flags: u16, + next: u16, + ) { + let base = DESC_TABLE + index as u64 * DESC_SIZE; + mem.write_u64_at(base + DESC_ADDR_OFFSET, addr); + mem.write_u32_at(base + DESC_LEN_OFFSET, len); + mem.write_u16_at(base + DESC_FLAGS_OFFSET, flags); + mem.write_u16_at(base + DESC_NEXT_OFFSET, next); + } + + /// Set the avail ring index and add an entry. + fn push_avail(mem: &MockGuestMemory, ring_idx: u16, desc_head: u16) { + // Write ring entry. + let entry_off = AVAIL_RING + 4 + (ring_idx as u64) * 2; + mem.write_u16_at(entry_off, desc_head); + // Update avail idx. + mem.write_u16_at(AVAIL_RING + 2, ring_idx + 1); + } + + // --- Construction tests --- + + #[test] + fn test_new_queue() { + let q = Virtqueue::new(256); + assert_eq!(q.max_size(), 256); + assert_eq!(q.size(), 0); + assert!(!q.is_ready()); + } + + #[test] + fn test_queue_configuration() { + let mut q = Virtqueue::new(256); + q.set_size(128); + q.set_desc_table(0x1000); + q.set_avail_ring(0x2000); + q.set_used_ring(0x3000); + q.set_ready(true); + assert_eq!(q.size(), 128); + assert!(q.is_ready()); + } + + #[test] + fn test_queue_reset() { + let mut q = setup_queue(256); + assert!(q.is_ready()); + q.reset(); + assert!(!q.is_ready()); + assert_eq!(q.size(), 0); + } + + // --- pop_avail tests --- + + #[test] + fn test_pop_avail_empty() { + let mut q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + // Avail idx = 0, last_avail_idx = 0 -> nothing. 
+ assert!(q.pop_avail(&mem).unwrap().is_none()); + } + + #[test] + fn test_pop_avail_not_ready() { + let mut q = Virtqueue::new(256); + let mem = MockGuestMemory::new(0x4000); + assert!(q.pop_avail(&mem).unwrap().is_none()); + } + + #[test] + fn test_pop_avail_single() { + let mut q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + push_avail(&mem, 0, 42); + + let head = q.pop_avail(&mem).unwrap(); + assert_eq!(head, Some(42)); + + // No more available. + assert!(q.pop_avail(&mem).unwrap().is_none()); + } + + #[test] + fn test_pop_avail_multiple() { + let mut q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + push_avail(&mem, 0, 10); + // Push second: ring[1]=20, idx=2 + mem.write_u16_at(AVAIL_RING + 4 + 2, 20); + mem.write_u16_at(AVAIL_RING + 2, 2); + + assert_eq!(q.pop_avail(&mem).unwrap(), Some(10)); + assert_eq!(q.pop_avail(&mem).unwrap(), Some(20)); + assert!(q.pop_avail(&mem).unwrap().is_none()); + } + + // --- read_desc_chain tests --- + + #[test] + fn test_read_single_descriptor() { + let q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + // Descriptor 0: addr=0x5000, len=512, no flags, no next. + write_descriptor(&mem, 0, 0x5000, 512, 0, 0); + + let chain = q.read_desc_chain(0, &mem).unwrap(); + assert_eq!(chain.len(), 1); + assert_eq!(chain[0].addr, 0x5000); + assert_eq!(chain[0].len, 512); + assert!(!chain[0].is_write()); + assert!(!chain[0].has_next()); + } + + #[test] + fn test_read_chained_descriptors() { + let q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + // Descriptor 0 -> 1 -> 2 (virtio-blk: header -> data -> status). + write_descriptor(&mem, 0, 0x5000, 16, VIRTQ_DESC_F_NEXT, 1); + write_descriptor( + &mem, + 1, + 0x6000, + 512, + VIRTQ_DESC_F_NEXT | VIRTQ_DESC_F_WRITE, + 2, + ); + write_descriptor(&mem, 2, 0x7000, 1, VIRTQ_DESC_F_WRITE, 0); + + let chain = q.read_desc_chain(0, &mem).unwrap(); + assert_eq!(chain.len(), 3); + + // Header (device-readable). 
+ assert_eq!(chain[0].addr, 0x5000); + assert_eq!(chain[0].len, 16); + assert!(!chain[0].is_write()); + assert!(chain[0].has_next()); + + // Data buffer (device-writable). + assert_eq!(chain[1].addr, 0x6000); + assert_eq!(chain[1].len, 512); + assert!(chain[1].is_write()); + + // Status (device-writable). + assert_eq!(chain[2].addr, 0x7000); + assert_eq!(chain[2].len, 1); + assert!(chain[2].is_write()); + assert!(!chain[2].has_next()); + } + + #[test] + fn test_chain_cycle_detection() { + let q = setup_queue(4); + let mem = MockGuestMemory::new(0x4000); + + // Descriptor 0 -> 1 -> 0 (cycle). + write_descriptor(&mem, 0, 0x5000, 16, VIRTQ_DESC_F_NEXT, 1); + write_descriptor(&mem, 1, 0x6000, 512, VIRTQ_DESC_F_NEXT, 0); + + let result = q.read_desc_chain(0, &mem); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("cycle"), "error should mention cycle: {}", err); + } + + #[test] + fn test_descriptor_index_out_of_bounds() { + let q = setup_queue(4); + let mem = MockGuestMemory::new(0x4000); + + let result = q.read_desc_chain(5, &mem); + assert!(result.is_err()); + } + + // --- add_used tests --- + + #[test] + fn test_add_used_single() { + let mut q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + q.add_used(42, 512, &mem).unwrap(); + + // Check used ring: idx should be 1. + let used_idx = mem.read_u16_at(USED_RING + 2); + assert_eq!(used_idx, 1); + + // Check used ring entry: {id=42, len=512}. + let entry_id = mem.read_u32_at(USED_RING + 4); + let entry_len = mem.read_u32_at(USED_RING + 4 + 4); + assert_eq!(entry_id, 42); + assert_eq!(entry_len, 512); + } + + #[test] + fn test_add_used_multiple() { + let mut q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + q.add_used(0, 100, &mem).unwrap(); + q.add_used(3, 200, &mem).unwrap(); + + let used_idx = mem.read_u16_at(USED_RING + 2); + assert_eq!(used_idx, 2); + + // First entry. 
+ assert_eq!(mem.read_u32_at(USED_RING + 4), 0); + assert_eq!(mem.read_u32_at(USED_RING + 8), 100); + + // Second entry. + assert_eq!(mem.read_u32_at(USED_RING + 12), 3); + assert_eq!(mem.read_u32_at(USED_RING + 16), 200); + } + + #[test] + fn test_add_used_not_ready() { + let mut q = Virtqueue::new(256); + let mem = MockGuestMemory::new(0x4000); + assert!(q.add_used(0, 0, &mem).is_err()); + } + + // --- Full round-trip: avail -> process -> used --- + + #[test] + fn test_full_roundtrip() { + let mut q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + // Set up a single-descriptor buffer. + write_descriptor(&mem, 5, 0x8000, 1024, VIRTQ_DESC_F_WRITE, 0); + push_avail(&mem, 0, 5); + + // Pop available. + let head = q.pop_avail(&mem).unwrap().expect("should have buffer"); + assert_eq!(head, 5); + + // Read chain. + let chain = q.read_desc_chain(head, &mem).unwrap(); + assert_eq!(chain.len(), 1); + assert_eq!(chain[0].len, 1024); + + // Complete: add to used. + q.add_used(head, 1024, &mem).unwrap(); + + let used_idx = mem.read_u16_at(USED_RING + 2); + assert_eq!(used_idx, 1); + } + + // --- Wrapping behavior --- + + #[test] + fn test_avail_index_wraps() { + let mut q = setup_queue(4); + let mem = MockGuestMemory::new(0x4000); + + // Simulate avail idx at u16::MAX boundary. + q.last_avail_idx = u16::MAX; + // Set avail ring idx to u16::MAX + 1 = 0 (wraps). + mem.write_u16_at(AVAIL_RING + 2, 0); + + // last_avail_idx (65535) == avail_idx (0 after wrap)? + // No: 65535 != 0, so we should get a buffer. + // Ring offset: (65535 % 4) * 2 = 3 * 2 = 6 -> ring[3] + mem.write_u16_at(AVAIL_RING + 4 + 6, 2); + + let head = q.pop_avail(&mem).unwrap(); + assert_eq!(head, Some(2)); + assert_eq!(q.last_avail_idx, 0); // Wrapped. 
+ } + + // --- Virtio-blk style 3-descriptor chain --- + + #[test] + fn test_virtio_blk_chain() { + let q = setup_queue(256); + let mem = MockGuestMemory::new(0x4000); + + // Header (device-readable): type=IN, sector=0 + write_descriptor(&mem, 0, 0xA000, 16, VIRTQ_DESC_F_NEXT, 1); + // Data buffer (device-writable): 512 bytes + write_descriptor( + &mem, + 1, + 0xB000, + 512, + VIRTQ_DESC_F_NEXT | VIRTQ_DESC_F_WRITE, + 2, + ); + // Status (device-writable): 1 byte + write_descriptor(&mem, 2, 0xC000, 1, VIRTQ_DESC_F_WRITE, 0); + + let chain = q.read_desc_chain(0, &mem).unwrap(); + assert_eq!(chain.len(), 3); + assert!(!chain[0].is_write()); // Header is device-readable. + assert!(chain[1].is_write()); // Data is device-writable. + assert!(chain[2].is_write()); // Status is device-writable. + } + + // --- Descriptor flags --- + + #[test] + fn test_descriptor_flags() { + let desc = Descriptor { + addr: 0, + len: 0, + flags: VIRTQ_DESC_F_WRITE | VIRTQ_DESC_F_NEXT, + next: 1, + }; + assert!(desc.is_write()); + assert!(desc.has_next()); + + let desc2 = Descriptor { + addr: 0, + len: 0, + flags: 0, + next: 0, + }; + assert!(!desc2.is_write()); + assert!(!desc2.has_next()); + } +} diff --git a/src/vmm/src/windows/devices/virtio/rng.rs b/src/vmm/src/windows/devices/virtio/rng.rs new file mode 100644 index 000000000..e723895fa --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/rng.rs @@ -0,0 +1,151 @@ +//! Virtio-rng device (virtio spec v1.2 Section 5.4). +//! +//! Provides entropy to the guest via `/dev/hwrng`. The guest driver +//! submits device-writable buffers; the device fills them with random +//! bytes and returns them on the used ring. + +use super::mmio::VirtioDeviceBackend; +use super::queue::{GuestMemoryAccessor, Virtqueue}; + +/// Virtio device ID for entropy source (spec 5.4). +const VIRTIO_ID_RNG: u32 = 4; + +/// VIRTIO_F_VERSION_1 — bit 32 (feature page 1, bit 0). +const VIRTIO_F_VERSION_1_PAGE1: u32 = 1; + +/// Maximum queue size for the request queue. 
+const QUEUE_MAX_SIZE: u16 = 256; + +/// Virtio-rng backend. +/// +/// Purely guest-initiated: the guest submits device-writable buffers, +/// the device fills them with random bytes from the host OS entropy pool. +/// No async worker or polling needed. +pub struct VirtioRng { + _priv: (), +} + +impl VirtioRng { + pub fn new() -> Self { + VirtioRng { _priv: () } + } +} + +impl VirtioDeviceBackend for VirtioRng { + fn device_id(&self) -> u32 { + VIRTIO_ID_RNG + } + + fn device_features(&self, page: u32) -> u32 { + match page { + 0 => 0, + 1 => VIRTIO_F_VERSION_1_PAGE1, + _ => 0, + } + } + + fn read_config(&self, _offset: u64) -> u32 { + 0 // No config space. + } + + fn num_queues(&self) -> usize { + 1 + } + + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + QUEUE_MAX_SIZE + } + + fn queue_notify( + &mut self, + _queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + let mut raised = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(e) => { + log::warn!("virtio-rng: failed to read descriptor chain: {}", e); + break; + } + }; + + let mut total_written = 0u32; + for desc in &chain { + if !desc.is_write() { + continue; // Skip device-readable descriptors. + } + + // Fill with random bytes using ThreadRng (infallible, seeds from OS). 
+ let len = desc.len as usize; + let mut buf = vec![0u8; len]; + rand::RngCore::fill_bytes(&mut rand::rng(), &mut buf); + + if let Err(e) = mem.write_at(desc.addr, &buf) { + log::warn!("virtio-rng: failed to write random bytes: {}", e); + break; + } + total_written += desc.len; + } + + if let Err(e) = queue.add_used(head, total_written, mem) { + log::warn!("virtio-rng: failed to add used buffer: {}", e); + break; + } + raised = true; + } + + raised + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_device_id() { + let rng = VirtioRng::new(); + assert_eq!(rng.device_id(), 4); + } + + #[test] + fn test_num_queues() { + let rng = VirtioRng::new(); + assert_eq!(rng.num_queues(), 1); + } + + #[test] + fn test_features_page0() { + let rng = VirtioRng::new(); + assert_eq!(rng.device_features(0), 0); + } + + #[test] + fn test_features_page1_version_1() { + let rng = VirtioRng::new(); + assert_eq!(rng.device_features(1), 1); // VIRTIO_F_VERSION_1 + } + + #[test] + fn test_features_page2_zero() { + let rng = VirtioRng::new(); + assert_eq!(rng.device_features(2), 0); + } + + #[test] + fn test_read_config_returns_zero() { + let rng = VirtioRng::new(); + assert_eq!(rng.read_config(0), 0); + assert_eq!(rng.read_config(4), 0); + } + + #[test] + fn test_queue_max_size() { + let rng = VirtioRng::new(); + assert_eq!(rng.queue_max_size(0), 256); + } +} diff --git a/src/vmm/src/windows/devices/virtio/vsock/connection.rs b/src/vmm/src/windows/devices/virtio/vsock/connection.rs new file mode 100644 index 000000000..54fc06544 --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/vsock/connection.rs @@ -0,0 +1,778 @@ +//! Vsock connection state machine with credit-based flow control. +//! +//! Each vsock connection tracks the state of a bidirectional byte stream +//! between a guest port and a host port. Flow control follows the virtio +//! spec (Section 5.10.6.3): each side advertises buffer space (buf_alloc) +//! and reports bytes consumed (fwd_cnt). 
The peer computes available +//! send credit as: `peer_buf_alloc - (tx_cnt - peer_fwd_cnt)`. + +use super::packet::{ + VsockHeader, VSOCK_OP_CREDIT_REQUEST, VSOCK_OP_CREDIT_UPDATE, VSOCK_OP_REQUEST, + VSOCK_OP_RESPONSE, VSOCK_OP_RST, VSOCK_OP_RW, VSOCK_OP_SHUTDOWN, VSOCK_SHUTDOWN_RECV, + VSOCK_SHUTDOWN_SEND, +}; + +/// Default buffer space we advertise to the peer (64 KiB). +const DEFAULT_BUF_ALLOC: u32 = 65536; + +/// Connection state. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnState { + /// No connection established. + Idle, + /// REQUEST sent/received, waiting for RESPONSE. + Connecting, + /// Data transfer active. + Connected, + /// SHUTDOWN sent or received; draining. + Closing, + /// Connection fully closed. + Closed, +} + +/// A single vsock connection between a guest port and a host port. +pub struct VsockConnection { + state: ConnState, + pub local_cid: u64, + pub local_port: u32, + pub peer_cid: u64, + pub peer_port: u32, + // Our credit: how much buffer space we offer to the peer. + buf_alloc: u32, + // Bytes we have consumed (forwarded to host TCP socket). + fwd_cnt: u32, + // Peer's advertised buffer space. + peer_buf_alloc: u32, + // Peer's forwarded count (bytes peer has consumed). + peer_fwd_cnt: u32, + // Total bytes we have sent to the peer (to compute remaining credit). + tx_cnt: u32, + // Host-to-guest transmit buffer. + tx_buf: Vec, + // Whether the peer has requested a credit update. + credit_update_needed: bool, +} + +impl VsockConnection { + /// Create a new connection in the Idle state. + pub fn new(local_cid: u64, local_port: u32, peer_cid: u64, peer_port: u32) -> Self { + VsockConnection { + state: ConnState::Idle, + local_cid, + local_port, + peer_cid, + peer_port, + buf_alloc: DEFAULT_BUF_ALLOC, + fwd_cnt: 0, + peer_buf_alloc: 0, + peer_fwd_cnt: 0, + tx_cnt: 0, + tx_buf: Vec::new(), + credit_update_needed: false, + } + } + + /// Current connection state. 
+ pub fn state(&self) -> ConnState { + self.state + } + + /// Our advertised buffer space. + pub fn buf_alloc(&self) -> u32 { + self.buf_alloc + } + + /// Bytes we have consumed (forwarded to host side). + pub fn fwd_cnt(&self) -> u32 { + self.fwd_cnt + } + + /// Total bytes we have sent to the peer. + pub fn tx_cnt(&self) -> u32 { + self.tx_cnt + } + + /// Bytes buffered for host-to-guest transmission. + pub fn tx_buf_len(&self) -> usize { + self.tx_buf.len() + } + + /// Available credit to send data to the peer. + /// + /// `peer_buf_alloc - (tx_cnt - peer_fwd_cnt)` per spec 5.10.6.3. + pub fn peer_credit(&self) -> u32 { + let in_flight = self.tx_cnt.wrapping_sub(self.peer_fwd_cnt); + self.peer_buf_alloc.saturating_sub(in_flight) + } + + /// Whether we need to send a credit update to the peer. + pub fn needs_credit_update(&self) -> bool { + self.credit_update_needed + } + + /// Clear the credit update flag. + pub fn clear_credit_update(&mut self) { + self.credit_update_needed = false; + } + + /// Initiate a host-to-guest connection (host-initiated). + /// + /// Transitions Idle -> Connecting and returns a REQUEST header to send + /// to the guest via the RX queue. Returns None if not in Idle state. + pub fn initiate_connect(&mut self) -> Option { + if self.state != ConnState::Idle { + return None; + } + self.state = ConnState::Connecting; + Some(VsockHeader::new_request( + self.local_cid, + self.local_port, + self.peer_cid, + self.peer_port, + self.buf_alloc, + self.fwd_cnt, + )) + } + + /// Handle a REQUEST from the guest. + /// + /// Transitions Idle -> Connected and returns a RESPONSE header. + /// Returns None if the connection is not in Idle state (sends RST instead). + pub fn handle_request(&mut self, hdr: &VsockHeader) -> Option { + if self.state != ConnState::Idle { + return None; + } + + // Record peer's credit info from the REQUEST. 
+ self.peer_buf_alloc = hdr.buf_alloc; + self.peer_fwd_cnt = hdr.fwd_cnt; + self.state = ConnState::Connected; + log::debug!( + "vsock conn ({},{}) → Connected (guest REQUEST, buf_alloc={})", + self.local_port, + self.peer_port, + hdr.buf_alloc + ); + + Some(VsockHeader::new_response( + self.local_cid, + self.local_port, + self.peer_cid, + self.peer_port, + self.buf_alloc, + self.fwd_cnt, + )) + } + + /// Handle an RW (data) packet from the guest. + /// + /// Returns the payload data to forward to the host TCP socket. + /// Updates fwd_cnt. Returns None if not connected. + pub fn handle_rw(&mut self, payload: &[u8]) -> Option> { + if self.state != ConnState::Connected { + return None; + } + + self.fwd_cnt = self.fwd_cnt.wrapping_add(payload.len() as u32); + + // Check if we should proactively send a credit update. + // If the peer's remaining view of our buffer is below half, signal update. + let peer_view = self.buf_alloc.saturating_sub( + self.fwd_cnt + .wrapping_sub(/* they don't know fwd_cnt yet */ 0), + ); + if peer_view < self.buf_alloc / 2 { + self.credit_update_needed = true; + } + + Some(payload.to_vec()) + } + + /// Handle a SHUTDOWN from the guest. + pub fn handle_shutdown(&mut self, flags: u32) { + let old_state = self.state; + match self.state { + ConnState::Connected => { + if flags & (VSOCK_SHUTDOWN_SEND | VSOCK_SHUTDOWN_RECV) + == (VSOCK_SHUTDOWN_SEND | VSOCK_SHUTDOWN_RECV) + { + self.state = ConnState::Closed; + } else { + self.state = ConnState::Closing; + } + } + ConnState::Closing => { + self.state = ConnState::Closed; + } + _ => {} + } + if self.state != old_state { + log::debug!( + "vsock conn ({},{}) → {:?} (SHUTDOWN flags=0x{:x})", + self.local_port, + self.peer_port, + self.state, + flags + ); + } + } + + /// Handle a RST from the guest. 
+ pub fn handle_rst(&mut self) { + log::debug!( + "vsock conn ({},{}) → Closed (RST)", + self.local_port, + self.peer_port + ); + self.state = ConnState::Closed; + } + + /// Handle a credit update from the guest. + pub fn handle_credit_update(&mut self, hdr: &VsockHeader) { + self.peer_buf_alloc = hdr.buf_alloc; + self.peer_fwd_cnt = hdr.fwd_cnt; + } + + /// Handle a credit request from the guest. + pub fn handle_credit_request(&mut self) { + self.credit_update_needed = true; + } + + /// Enqueue data from the host for transmission to the guest. + /// + /// Returns the number of bytes actually enqueued (limited by peer credit). + pub fn enqueue_tx(&mut self, data: &[u8]) -> usize { + if self.state != ConnState::Connected { + return 0; + } + + let credit = self.peer_credit() as usize; + let to_send = data.len().min(credit); + if to_send > 0 { + self.tx_buf.extend_from_slice(&data[..to_send]); + } + to_send + } + + /// Drain pending host-to-guest data, limited by available credit. + /// + /// Returns data to be placed in an RX virtqueue buffer, along with + /// the header to prepend. + pub fn drain_tx(&mut self, max_payload: usize) -> Option<(VsockHeader, Vec)> { + if self.tx_buf.is_empty() { + return None; + } + + let send_len = self.tx_buf.len().min(max_payload); + let data: Vec = self.tx_buf.drain(..send_len).collect(); + + self.tx_cnt = self.tx_cnt.wrapping_add(data.len() as u32); + + let hdr = VsockHeader::new_rw( + self.local_cid, + self.local_port, + self.peer_cid, + self.peer_port, + data.len() as u32, + self.buf_alloc, + self.fwd_cnt, + ); + + Some((hdr, data)) + } + + /// Build a credit update header for this connection. + pub fn make_credit_update(&self) -> VsockHeader { + VsockHeader::new_credit_update( + self.local_cid, + self.local_port, + self.peer_cid, + self.peer_port, + self.buf_alloc, + self.fwd_cnt, + ) + } + + /// Build a RST header for this connection. 
+ pub fn make_rst(&self) -> VsockHeader { + VsockHeader::new_rst( + self.local_cid, + self.local_port, + self.peer_cid, + self.peer_port, + ) + } + + /// Dispatch a packet by operation code. + /// + /// Returns a response header to send back (if any), and optional + /// payload data to forward to the host side. + pub fn dispatch( + &mut self, + hdr: &VsockHeader, + payload: &[u8], + ) -> (Option, Option>) { + match hdr.op { + VSOCK_OP_REQUEST => { + let resp = self.handle_request(hdr); + if resp.is_none() { + // Already connected or invalid state -> RST. + return (Some(self.make_rst()), None); + } + (resp, None) + } + VSOCK_OP_RW => { + let data = self.handle_rw(payload); + let credit_hdr = if self.credit_update_needed { + self.credit_update_needed = false; + Some(self.make_credit_update()) + } else { + None + }; + (credit_hdr, data) + } + VSOCK_OP_SHUTDOWN => { + self.handle_shutdown(hdr.flags); + (None, None) + } + VSOCK_OP_RST => { + self.handle_rst(); + (None, None) + } + VSOCK_OP_RESPONSE => { + // Guest accepted our connection (host-initiated connect). + if self.state == ConnState::Connecting { + self.peer_buf_alloc = hdr.buf_alloc; + self.peer_fwd_cnt = hdr.fwd_cnt; + self.state = ConnState::Connected; + log::debug!( + "vsock conn ({},{}) → Connected (guest RESPONSE, buf_alloc={})", + self.local_port, + self.peer_port, + hdr.buf_alloc + ); + } + (None, None) + } + VSOCK_OP_CREDIT_UPDATE => { + self.handle_credit_update(hdr); + (None, None) + } + VSOCK_OP_CREDIT_REQUEST => { + self.handle_credit_request(); + let update = self.make_credit_update(); + self.credit_update_needed = false; + (Some(update), None) + } + _ => { + // Unknown op -> RST. 
+ (Some(self.make_rst()), None) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn guest_conn() -> VsockConnection { + // local = host (CID 2), peer = guest (CID 3) + VsockConnection::new(2, 2695, 3, 5000) + } + + fn make_request_hdr() -> VsockHeader { + VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + } + } + + // --- State transitions --- + + #[test] + fn test_new_connection_is_idle() { + let conn = guest_conn(); + assert_eq!(conn.state(), ConnState::Idle); + } + + #[test] + fn test_handle_request_transitions_to_connected() { + let mut conn = guest_conn(); + let hdr = make_request_hdr(); + let resp = conn.handle_request(&hdr); + assert!(resp.is_some()); + assert_eq!(conn.state(), ConnState::Connected); + + let r = resp.unwrap(); + assert_eq!(r.op, VSOCK_OP_RESPONSE); + assert_eq!(r.src_cid, 2); + assert_eq!(r.dst_cid, 3); + assert_eq!(r.buf_alloc, DEFAULT_BUF_ALLOC); + } + + // --- Host-initiated connection --- + + #[test] + fn test_initiate_connect_transitions_to_connecting() { + let mut conn = guest_conn(); + let req = conn.initiate_connect(); + assert!(req.is_some()); + assert_eq!(conn.state(), ConnState::Connecting); + + let r = req.unwrap(); + assert_eq!(r.op, VSOCK_OP_REQUEST); + assert_eq!(r.src_cid, 2); // HOST + assert_eq!(r.dst_cid, 3); // guest + assert_eq!(r.src_port, 2695); // local_port + assert_eq!(r.dst_port, 5000); // peer_port + assert_eq!(r.buf_alloc, DEFAULT_BUF_ALLOC); + } + + #[test] + fn test_initiate_connect_on_non_idle_returns_none() { + let mut conn = guest_conn(); + conn.initiate_connect(); + // Second call should fail (already Connecting). 
/// A guest RESPONSE completes a host-initiated connect and records the
/// guest's advertised buffer space as our send credit.
#[test]
fn test_response_transitions_connecting_to_connected() {
    let mut conn = guest_conn();
    let _ = conn.initiate_connect();
    assert_eq!(conn.state(), ConnState::Connecting);

    // Same addressing and credit fields as the request fixture; only the
    // operation differs.
    let guest_response = VsockHeader {
        op: VSOCK_OP_RESPONSE,
        ..make_request_hdr()
    };

    let (reply, forwarded) = conn.dispatch(&guest_response, &[]);
    assert!(reply.is_none()); // No response to a RESPONSE
    assert!(forwarded.is_none());
    assert_eq!(conn.state(), ConnState::Connected);
    assert_eq!(conn.peer_credit(), 32768);
}
/// Credit spent by sending is restored once the peer reports, via a credit
/// update, that it has consumed those bytes.
#[test]
fn test_peer_credit_recovers_with_update() {
    let mut conn = guest_conn();
    conn.handle_request(&make_request_hdr());

    // Send 1000 bytes to the guest.
    conn.enqueue_tx(&[0u8; 1000]);
    conn.drain_tx(1000);

    // Simulate peer consumed 1000 bytes.
    let peer_report = VsockHeader {
        op: VSOCK_OP_CREDIT_UPDATE,
        buf_alloc: 32768,
        fwd_cnt: 1000,
        ..make_request_hdr()
    };
    conn.handle_credit_update(&peer_report);

    assert_eq!(conn.peer_credit(), 32768);
}
/// `drain_tx` hands out at most `max_payload` bytes and keeps the rest queued.
#[test]
fn test_drain_tx_respects_max_payload() {
    let mut conn = guest_conn();
    conn.handle_request(&make_request_hdr());

    let pending = [0xBB; 1000];
    conn.enqueue_tx(&pending);

    let drained = conn.drain_tx(500);
    let (hdr, data) = drained.unwrap();
    assert_eq!(data.len(), 500);
    assert_eq!(hdr.len, 500);

    // Remaining data still in buffer.
    assert_eq!(conn.tx_buf_len(), 500);
}
/// An unrecognised operation code is answered with RST.
#[test]
fn test_dispatch_unknown_op_sends_rst() {
    let mut conn = guest_conn();
    conn.dispatch(&make_request_hdr(), &[]);

    // Same addressing as the request fixture, bogus op, no credit advertised.
    let bogus = VsockHeader {
        op: 99,
        buf_alloc: 0,
        ..make_request_hdr()
    };

    let (reply, _) = conn.dispatch(&bogus, &[]);
    assert!(reply.is_some());
    assert_eq!(reply.unwrap().op, VSOCK_OP_RST);
}
/// The credit-update header carries our current `fwd_cnt` and `buf_alloc`.
#[test]
fn test_make_credit_update() {
    let mut conn = guest_conn();
    conn.handle_request(&make_request_hdr());
    conn.handle_rw(b"hello"); // fwd_cnt = 5

    let update = conn.make_credit_update();
    assert_eq!(update.op, VSOCK_OP_CREDIT_UPDATE);
    assert_eq!(update.buf_alloc, DEFAULT_BUF_ALLOC);
    assert_eq!(update.fwd_cnt, 5);
}
+const VIRTIO_VSOCK_ID: u32 = 19; + +/// VIRTIO_F_VERSION_1 — bit 32 (page 1, bit 0). +const VIRTIO_F_VERSION_1_BIT: u32 = 0; + +/// Number of queues: RX, TX, Event. +const NUM_QUEUES: usize = 3; + +/// Queue index constants. +const RX_QUEUE: usize = 0; +const TX_QUEUE: usize = 1; +// const EVENT_QUEUE: usize = 2; // Not used yet. + +/// Maximum queue size. +const QUEUE_MAX_SIZE: u16 = 128; + +/// Connection key: (guest_port, host_port). +type ConnKey = (u32, u32); + +/// Starting ephemeral port for host-initiated vsock connections. +const EPHEMERAL_PORT_START: u32 = 49152; + +/// Virtio-vsock device with Unix socket host-side bridge. +pub struct VirtioVsock { + /// Guest CID (typically 3 for the first guest). + guest_cid: u64, + /// Active connections keyed by (guest_port, host_port). + connections: HashMap, + /// Unix socket listeners on the host side, keyed by vsock port. + /// Used for host-initiated connections (host UDS → guest vsock). + listeners: HashMap, + /// Outbound Unix socket targets keyed by vsock port. + /// Used for guest-initiated connections (guest vsock → host UDS). + /// When the guest connects to a port in this map, the device makes + /// an outbound Unix socket connection to the specified path. + connect_targets: HashMap, + /// Accepted Unix streams, keyed by (guest_port, host_port). + streams: HashMap, + /// Pending response/control packets to inject into the RX queue. + rx_pending: Vec<(VsockHeader, Vec)>, + /// Next ephemeral port for host-initiated connections. + next_host_port: u32, +} + +impl VirtioVsock { + /// Create a new vsock device with the given guest CID. + pub fn new(guest_cid: u64) -> Self { + VirtioVsock { + guest_cid, + connections: HashMap::new(), + listeners: HashMap::new(), + connect_targets: HashMap::new(), + streams: HashMap::new(), + rx_pending: Vec::new(), + next_host_port: EPHEMERAL_PORT_START, + } + } + + /// Register a Unix socket listener on `socket_path` for the given vsock port. 
+ /// + /// When a guest connects to this port via AF_VSOCK, the connection + /// is bridged to an accepted Unix socket client on this listener. + /// + /// Removes any stale socket file before binding. + pub fn listen_on(&mut self, vsock_port: u32, socket_path: &str) -> io::Result<()> { + // Remove stale socket file if it exists. + let _ = std::fs::remove_file(socket_path); + let listener = UnixListener::bind(socket_path)?; + listener.set_nonblocking(true)?; + self.listeners.insert(vsock_port, listener); + Ok(()) + } + + /// Register an outbound Unix socket target for guest-initiated connections. + /// + /// When the guest connects to `vsock_port`, the device makes an outbound + /// Unix socket connection to `host_path` instead of accepting from a listener. + /// Used for notification channels where the guest initiates the connection + /// and the host is already listening. + pub fn connect_to(&mut self, vsock_port: u32, host_path: String) { + self.connect_targets.insert(vsock_port, host_path); + } + + /// Get the guest CID. + pub fn guest_cid(&self) -> u64 { + self.guest_cid + } + + /// Number of active connections. + pub fn connection_count(&self) -> usize { + self.connections.len() + } + + /// Process the TX queue: read packets from guest, dispatch them. + fn process_tx(&mut self, queue: &mut Virtqueue, mem: &dyn GuestMemoryAccessor) -> bool { + let mut processed = false; + + while let Ok(Some(head)) = queue.pop_avail(mem) { + let chain = match queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + }; + + if chain.is_empty() { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + + // First descriptor: vsock header (device-readable). 
+ let hdr_desc = &chain[0]; + if (hdr_desc.len as usize) < VSOCK_HEADER_SIZE { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + + let hdr = match VsockHeader::read_from(mem, hdr_desc.addr) { + Ok(h) => h, + Err(_) => { + let _ = queue.add_used(head, 0, mem); + processed = true; + continue; + } + }; + + // Read payload from subsequent descriptors. + let mut payload = Vec::new(); + for desc in &chain[1..] { + if !desc.is_write() { + // Device-readable = payload data from guest. + let mut buf = vec![0u8; desc.len as usize]; + if mem.read_at(desc.addr, &mut buf).is_ok() { + payload.extend_from_slice(&buf); + } + } + } + + self.handle_guest_packet(&hdr, &payload); + + let _ = queue.add_used(head, 0, mem); + processed = true; + } + + processed + } + + /// Handle a packet from the guest. + fn handle_guest_packet(&mut self, hdr: &VsockHeader, payload: &[u8]) { + let key = (hdr.src_port, hdr.dst_port); + if !payload.is_empty() { + log::trace!( + "vsock TX: guest→host {} bytes, op={}, key=({},{})", + payload.len(), + hdr.op, + key.0, + key.1 + ); + } + + if hdr.op == VSOCK_OP_REQUEST { + self.handle_connect_request(hdr); + return; + } + + if let Some(conn) = self.connections.get_mut(&key) { + let (resp_hdr, fwd_data) = conn.dispatch(hdr, payload); + + // Forward data to host Unix socket. + // Use retry loop for non-blocking sockets (write_all fails on WouldBlock). 
+ if let Some(data) = fwd_data { + if let Some(stream) = self.streams.get_mut(&key) { + let mut written = 0; + let mut retries = 0; + while written < data.len() { + match stream.write(&data[written..]) { + Ok(n) => written += n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + retries += 1; + if retries > 1000 { + log::warn!( + "vsock write stuck: {}/{} bytes after {} retries, key=({},{})", + written, data.len(), retries, key.0, key.1 + ); + break; + } + std::thread::yield_now(); + } + Err(e) => { + log::warn!( + "vsock write failed: {}/{} bytes, err={}, key=({},{})", + written, + data.len(), + e, + key.0, + key.1 + ); + break; + } + } + } + } + } + + // Queue response packet (if any) for RX injection. + if let Some(r) = resp_hdr { + self.rx_pending.push((r, Vec::new())); + } + + // Clean up closed connections. + if conn.state() == ConnState::Closed { + self.connections.remove(&key); + self.streams.remove(&key); + } + } else { + // No connection for this port pair -> RST. + let rst = + VsockHeader::new_rst(VSOCK_CID_HOST, hdr.dst_port, self.guest_cid, hdr.src_port); + self.rx_pending.push((rst, Vec::new())); + } + } + + /// Handle a guest CONNECTION REQUEST. + fn handle_connect_request(&mut self, hdr: &VsockHeader) { + let key = (hdr.src_port, hdr.dst_port); + + // Try outbound connection first (guest-initiated → host UDS target). 
+ if let Some(path) = self.connect_targets.get(&hdr.dst_port).cloned() { + log::debug!("guest-initiated CONNECT: port={} → {}", hdr.dst_port, path); + let stream = match UnixStream::connect(&path) { + Ok(stream) => { + if let Err(e) = stream.set_nonblocking(true) { + log::warn!("guest-connect: set_nonblocking failed: {}", e); + } + log::debug!("UDS connect OK to {}", path); + stream + } + Err(ref e) => { + log::warn!("UDS connect FAILED to {}: {}", path, e); + let rst = VsockHeader::new_rst( + VSOCK_CID_HOST, + hdr.dst_port, + self.guest_cid, + hdr.src_port, + ); + self.rx_pending.push((rst, Vec::new())); + return; + } + }; + + let mut conn = + VsockConnection::new(VSOCK_CID_HOST, hdr.dst_port, self.guest_cid, hdr.src_port); + + if let Some(resp) = conn.handle_request(hdr) { + self.rx_pending.push((resp, Vec::new())); + self.connections.insert(key, conn); + self.streams.insert(key, stream); + } else { + let rst = VsockHeader::new_rst( + VSOCK_CID_HOST, + hdr.dst_port, + self.guest_cid, + hdr.src_port, + ); + self.rx_pending.push((rst, Vec::new())); + } + return; + } + + // Fall back to listener-based connection (host-initiated). + if !self.listeners.contains_key(&hdr.dst_port) { + let rst = + VsockHeader::new_rst(VSOCK_CID_HOST, hdr.dst_port, self.guest_cid, hdr.src_port); + self.rx_pending.push((rst, Vec::new())); + return; + } + + // Try to accept a pending Unix socket connection on this listener. + let stream = if let Some(listener) = self.listeners.get(&hdr.dst_port) { + match listener.accept() { + Ok((stream, _addr)) => { + let _ = stream.set_nonblocking(true); + Some(stream) + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // No pending connection yet — still accept the vsock connection. + // Data will buffer until a client connects. 
+ None + } + Err(_) => { + let rst = VsockHeader::new_rst( + VSOCK_CID_HOST, + hdr.dst_port, + self.guest_cid, + hdr.src_port, + ); + self.rx_pending.push((rst, Vec::new())); + return; + } + } + } else { + None + }; + + // Create and register the connection. + let mut conn = + VsockConnection::new(VSOCK_CID_HOST, hdr.dst_port, self.guest_cid, hdr.src_port); + + if let Some(resp) = conn.handle_request(hdr) { + self.rx_pending.push((resp, Vec::new())); + self.connections.insert(key, conn); + if let Some(s) = stream { + self.streams.insert(key, s); + } + } else { + let rst = + VsockHeader::new_rst(VSOCK_CID_HOST, hdr.dst_port, self.guest_cid, hdr.src_port); + self.rx_pending.push((rst, Vec::new())); + } + } + + /// Allocate the next ephemeral host port for host-initiated connections. + fn alloc_host_port(&mut self) -> u32 { + let port = self.next_host_port; + self.next_host_port = self.next_host_port.wrapping_add(1); + if self.next_host_port < EPHEMERAL_PORT_START { + self.next_host_port = EPHEMERAL_PORT_START; + } + port + } + + /// Poll Unix socket listeners for pending connections and initiate vsock handshakes. + /// + /// When a host client connects to a listener, this method: + /// 1. Accepts the Unix socket connection + /// 2. Allocates an ephemeral host port for the vsock side + /// 3. Creates a VsockConnection in Connecting state + /// 4. Generates a REQUEST packet to send to the guest via RX queue + /// 5. 
Stores the Unix stream (data is NOT read until Connected) + fn poll_listeners(&mut self) { + let vsock_ports: Vec = self.listeners.keys().copied().collect(); + + for vsock_port in vsock_ports { + let stream = if let Some(listener) = self.listeners.get(&vsock_port) { + match listener.accept() { + Ok((stream, _addr)) => { + if let Err(e) = stream.set_nonblocking(true) { + log::warn!("vsock set_nonblocking failed: {}", e); + } + stream + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(_) => continue, + } + } else { + continue; + }; + + let host_port = self.alloc_host_port(); + let key = (vsock_port, host_port); + + let mut conn = + VsockConnection::new(VSOCK_CID_HOST, host_port, self.guest_cid, vsock_port); + + if let Some(req) = conn.initiate_connect() { + log::debug!( + "host-initiated CONNECT: vsock_port={}, host_port={}, queuing REQUEST", + vsock_port, + host_port + ); + self.rx_pending.push((req, Vec::new())); + self.connections.insert(key, conn); + self.streams.insert(key, stream); + } + } + } + + /// Poll Unix streams for incoming data and queue it for RX injection. + fn poll_streams(&mut self) { + // Collect keys first to avoid borrow issues. + let keys: Vec = self.streams.keys().copied().collect(); + + for key in keys { + // Skip streams whose vsock connection is still handshaking. + // Data stays in the kernel receive buffer until Connected. + if let Some(conn) = self.connections.get(&key) { + if conn.state() != ConnState::Connected { + continue; + } + } + + let mut buf = [0u8; 65536]; + let data = if let Some(stream) = self.streams.get_mut(&key) { + match stream.read(&mut buf) { + Ok(0) => { + // Unix socket connection closed. Send SHUTDOWN to guest. 
+ log::debug!("UDS EOF, key=({},{})", key.0, key.1); + if let Some(conn) = self.connections.get(&key) { + let hdr = VsockHeader::new_shutdown( + conn.local_cid, + conn.local_port, + conn.peer_cid, + conn.peer_port, + packet::VSOCK_SHUTDOWN_SEND | packet::VSOCK_SHUTDOWN_RECV, + ); + self.rx_pending.push((hdr, Vec::new())); + } + self.streams.remove(&key); + if let Some(conn) = self.connections.get_mut(&key) { + conn.handle_shutdown( + packet::VSOCK_SHUTDOWN_SEND | packet::VSOCK_SHUTDOWN_RECV, + ); + } + continue; + } + Ok(n) => { + log::trace!("UDS read {} bytes, key=({},{})", n, key.0, key.1); + Some(buf[..n].to_vec()) + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => None, + Err(ref e) => { + // I/O error on Unix stream. RST the vsock connection. + log::warn!( + "vsock UDS read error: {} (raw={:?}), key=({},{})", + e, + e.raw_os_error(), + key.0, + key.1 + ); + if let Some(conn) = self.connections.get(&key) { + let rst = conn.make_rst(); + self.rx_pending.push((rst, Vec::new())); + } + self.streams.remove(&key); + self.connections.remove(&key); + continue; + } + } + } else { + continue; + }; + + // Enqueue data from TCP into the connection's TX buffer. + if let Some(data) = data { + if let Some(conn) = self.connections.get_mut(&key) { + let enqueued = conn.enqueue_tx(&data); + if enqueued < data.len() { + log::debug!( + "vsock enqueue_tx partial: {}/{} bytes, credit={}, key=({},{})", + enqueued, + data.len(), + conn.peer_credit(), + key.0, + key.1 + ); + } + } + } + } + } + + /// Inject pending packets into the RX queue. + fn inject_rx(&mut self, rx_queue: &mut Virtqueue, mem: &dyn GuestMemoryAccessor) -> bool { + let mut injected = false; + + // First: drain connection TX buffers into rx_pending. + let keys: Vec = self.connections.keys().copied().collect(); + for key in keys { + if let Some(conn) = self.connections.get_mut(&key) { + // Also check for credit updates. 
+ if conn.needs_credit_update() { + let hdr = conn.make_credit_update(); + conn.clear_credit_update(); + self.rx_pending.push((hdr, Vec::new())); + } + + // Drain TX data. + while let Some((hdr, data)) = conn.drain_tx(4096) { + self.rx_pending.push((hdr, data)); + } + } + } + + // Inject all pending packets. + while !self.rx_pending.is_empty() { + let head = match rx_queue.pop_avail(mem) { + Ok(Some(h)) => h, + _ => { + log::debug!( + "vsock inject_rx: no available RX buffers, {} packets pending", + self.rx_pending.len() + ); + break; + } + }; + + let chain = match rx_queue.read_desc_chain(head, mem) { + Ok(c) => c, + Err(_) => { + let _ = rx_queue.add_used(head, 0, mem); + injected = true; + continue; + } + }; + + let (hdr, payload) = self.rx_pending.remove(0); + + // Write header + payload to device-writable descriptors. + let total_data = hdr + .to_bytes() + .to_vec() + .into_iter() + .chain(payload.into_iter()) + .collect::>(); + + let mut offset = 0; + let mut total_written = 0u32; + for desc in &chain { + if !desc.is_write() { + continue; + } + let remaining = total_data.len().saturating_sub(offset); + let to_write = remaining.min(desc.len as usize); + if to_write > 0 { + let _ = mem.write_at(desc.addr, &total_data[offset..offset + to_write]); + offset += to_write; + total_written += to_write as u32; + } + } + + let _ = rx_queue.add_used(head, total_written, mem); + injected = true; + } + + injected + } +} + +impl VirtioDeviceBackend for VirtioVsock { + fn device_id(&self) -> u32 { + VIRTIO_VSOCK_ID + } + + fn device_features(&self, page: u32) -> u32 { + match page { + 1 => 1 << VIRTIO_F_VERSION_1_BIT, + _ => 0, + } + } + + fn read_config(&self, offset: u64) -> u32 { + // Config space: guest_cid (u64 at offset 0). 
+ match offset { + 0 => self.guest_cid as u32, + 4 => (self.guest_cid >> 32) as u32, + _ => 0, + } + } + + fn queue_notify( + &mut self, + queue_idx: u32, + queue: &mut Virtqueue, + mem: &dyn GuestMemoryAccessor, + ) -> bool { + match queue_idx as usize { + TX_QUEUE => self.process_tx(queue, mem), + _ => false, + } + } + + fn num_queues(&self) -> usize { + NUM_QUEUES + } + + fn queue_max_size(&self, _queue_idx: u32) -> u16 { + QUEUE_MAX_SIZE + } + + fn poll(&mut self, queues: &mut [Virtqueue], mem: &dyn GuestMemoryAccessor) -> bool { + // Accept new Unix socket connections and initiate vsock handshakes. + self.poll_listeners(); + + // Poll Unix streams for incoming data. + let pending_before = self.rx_pending.len(); + self.poll_streams(); + let new_data = self.rx_pending.len() - pending_before; + if new_data > 0 { + log::trace!( + "vsock poll: UDS produced {} new packets, total pending={}", + new_data, + self.rx_pending.len() + ); + } + + // Inject any pending data into the RX queue. + if queues.len() > RX_QUEUE { + let injected = self.inject_rx(&mut queues[RX_QUEUE], mem); + if injected { + log::debug!( + "vsock poll: injected data into RX queue, conns={}", + self.connections.len() + ); + } + injected + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::super::super::super::error::Result; + use super::super::queue::Virtqueue; + use super::packet::VSOCK_OP_RST; + use super::*; + use std::cell::RefCell; + + struct MockMem { + data: RefCell>, + } + + impl MockMem { + fn new(size: usize) -> Self { + MockMem { + data: RefCell::new(vec![0u8; size]), + } + } + + fn write_bytes(&self, addr: u64, bytes: &[u8]) { + let a = addr as usize; + let mut data = self.data.borrow_mut(); + data[a..a + bytes.len()].copy_from_slice(bytes); + } + + fn read_bytes(&self, addr: u64, len: usize) -> Vec { + let a = addr as usize; + let data = self.data.borrow(); + data[a..a + len].to_vec() + } + + fn write_u16_at(&self, addr: u64, val: u16) { + self.write_bytes(addr, 
&val.to_le_bytes()); + } + + fn write_u32_at(&self, addr: u64, val: u32) { + self.write_bytes(addr, &val.to_le_bytes()); + } + + fn write_u64_at(&self, addr: u64, val: u64) { + self.write_bytes(addr, &val.to_le_bytes()); + } + } + + impl GuestMemoryAccessor for MockMem { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + let a = addr as usize; + let data = self.data.borrow(); + if a + buf.len() > data.len() { + return Err(super::super::super::super::error::WkrunError::Memory( + "out of bounds".into(), + )); + } + buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + let a = addr as usize; + let mut mem = self.data.borrow_mut(); + if a + data.len() > mem.len() { + return Err(super::super::super::super::error::WkrunError::Memory( + "out of bounds".into(), + )); + } + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + // Memory layout for tests: + // DESC_TABLE at 0x0000 (128 entries * 16 bytes = 2048) + // AVAIL_RING at 0x0800 + // USED_RING at 0x1000 + // BUFFERS at 0x2000+ + const DESC_TABLE: u64 = 0x0000; + const DESC_SIZE: u64 = 16; + const AVAIL_RING: u64 = 0x0800; + const USED_RING: u64 = 0x1000; + const BUF_BASE: u64 = 0x2000; + + fn setup_queue(max_size: u16) -> Virtqueue { + let mut q = Virtqueue::new(max_size); + q.set_size(max_size); + q.set_desc_table(DESC_TABLE); + q.set_avail_ring(AVAIL_RING); + q.set_used_ring(USED_RING); + q.set_ready(true); + q + } + + fn write_descriptor(mem: &MockMem, index: u16, addr: u64, len: u32, flags: u16, next: u16) { + let base = DESC_TABLE + index as u64 * DESC_SIZE; + mem.write_u64_at(base, addr); + mem.write_u32_at(base + 8, len); + mem.write_u16_at(base + 12, flags); + mem.write_u16_at(base + 14, next); + } + + fn push_avail(mem: &MockMem, ring_idx: u16, desc_head: u16) { + let entry_off = AVAIL_RING + 4 + (ring_idx as u64) * 2; + mem.write_u16_at(entry_off, desc_head); + mem.write_u16_at(AVAIL_RING + 2, ring_idx + 1); + } 
+ + // --- Device identity --- + + #[test] + fn test_device_id() { + let dev = VirtioVsock::new(3); + assert_eq!(dev.device_id(), 19); + } + + #[test] + fn test_num_queues() { + let dev = VirtioVsock::new(3); + assert_eq!(dev.num_queues(), 3); + } + + #[test] + fn test_queue_max_size() { + let dev = VirtioVsock::new(3); + assert_eq!(dev.queue_max_size(0), 128); + assert_eq!(dev.queue_max_size(1), 128); + assert_eq!(dev.queue_max_size(2), 128); + } + + #[test] + fn test_version_1_feature() { + let dev = VirtioVsock::new(3); + assert_eq!(dev.device_features(0), 0); + assert_eq!(dev.device_features(1), 1); // VIRTIO_F_VERSION_1 + } + + // --- Config space --- + + #[test] + fn test_config_guest_cid() { + let dev = VirtioVsock::new(3); + assert_eq!(dev.read_config(0), 3); // Low 32 bits. + assert_eq!(dev.read_config(4), 0); // High 32 bits. + } + + #[test] + fn test_config_large_cid() { + let dev = VirtioVsock::new(0x1_0000_0003); + assert_eq!(dev.read_config(0), 3); + assert_eq!(dev.read_config(4), 1); + } + + // --- TX queue: REQUEST handling --- + + #[test] + fn test_tx_request_no_listener_sends_rst() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + // Write a REQUEST header to guest memory. + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + + // Set up descriptor: header only. + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + + // Should have a RST pending in rx_pending. + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, VSOCK_OP_RST); + } + + /// Create a temporary socket path for tests. 
+ fn temp_socket_path(name: &str) -> (std::path::PathBuf, tempfile::TempDir) { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join(name); + (path, dir) + } + + #[test] + fn test_tx_request_with_listener_sends_response() { + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("vsock-test.sock"); + let vsock_port = 2695u32; + dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: vsock_port, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + // Should have a RESPONSE pending. + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, packet::VSOCK_OP_RESPONSE); + assert_eq!(dev.connection_count(), 1); + } + + // --- TX queue: RW handling --- + + #[test] + fn test_tx_rw_forwards_data() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + // Establish connection directly. + let req_hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + dev.handle_guest_packet(&req_hdr, &[]); + // The REQUEST without a listener sends RST, so let's set up directly. + dev.rx_pending.clear(); + + // Manually create a connected state. + let mut conn = VsockConnection::new(VSOCK_CID_HOST, 2695, 3, 5000); + conn.handle_request(&req_hdr); + dev.connections.insert((5000, 2695), conn); + + // Now send an RW packet. 
+ let rw_hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 5, + type_: 1, + op: packet::VSOCK_OP_RW, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &rw_hdr.to_bytes()); + mem.write_bytes(BUF_BASE + VSOCK_HEADER_SIZE as u64, b"hello"); + + // Two descriptors: header (readable) + payload (readable). + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 1, 1); // NEXT + write_descriptor(&mem, 1, BUF_BASE + VSOCK_HEADER_SIZE as u64, 5, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + + // Data was forwarded (no stream connected, so just the connection absorbed it). + let conn = dev.connections.get(&(5000, 2695)).unwrap(); + assert_eq!(conn.fwd_cnt(), 5); + } + + // --- TX queue: SHUTDOWN handling --- + + #[test] + fn test_tx_shutdown_closes_connection() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + // Set up a connected connection. + let req_hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + let mut conn = VsockConnection::new(VSOCK_CID_HOST, 2695, 3, 5000); + conn.handle_request(&req_hdr); + dev.connections.insert((5000, 2695), conn); + + // Send SHUTDOWN with both flags. + let shut_hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: packet::VSOCK_OP_SHUTDOWN, + flags: packet::VSOCK_SHUTDOWN_SEND | packet::VSOCK_SHUTDOWN_RECV, + buf_alloc: 0, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &shut_hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + // Connection should be removed. 
+ assert_eq!(dev.connection_count(), 0); + } + + // --- TX queue: RST for unknown connection --- + + #[test] + fn test_tx_rw_to_unknown_port_sends_rst() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let rw_hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 9999, + dst_port: 8888, + len: 0, + type_: 1, + op: packet::VSOCK_OP_RW, + flags: 0, + buf_alloc: 0, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &rw_hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, VSOCK_OP_RST); + } + + // --- RX queue: inject pending --- + + #[test] + fn test_inject_rx_writes_header_to_queue() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(128); + + // Set up an RX buffer (device-writable). + write_descriptor(&mem, 0, BUF_BASE, 256, 2, 0); // WRITE flag = 2 + push_avail(&mem, 0, 0); + + // Queue a RESPONSE packet. + let resp = VsockHeader::new_response(2, 2695, 3, 5000, 65536, 0); + dev.rx_pending.push((resp, Vec::new())); + + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(injected); + + // Read back the header from guest memory. + let written = mem.read_bytes(BUF_BASE, VSOCK_HEADER_SIZE); + let read_hdr = VsockHeader::from_bytes(&written.try_into().unwrap()); + assert_eq!(read_hdr.op, packet::VSOCK_OP_RESPONSE); + assert_eq!(read_hdr.src_cid, 2); + assert_eq!(read_hdr.dst_cid, 3); + } + + #[test] + fn test_inject_rx_with_payload() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(128); + + // RX buffer: 256 bytes device-writable. 
+ write_descriptor(&mem, 0, BUF_BASE, 256, 2, 0); + push_avail(&mem, 0, 0); + + let rw = VsockHeader::new_rw(2, 2695, 3, 5000, 5, 65536, 0); + dev.rx_pending.push((rw, b"hello".to_vec())); + + dev.inject_rx(&mut rx_queue, &mem); + + // Check header. + let hdr_bytes = mem.read_bytes(BUF_BASE, VSOCK_HEADER_SIZE); + let hdr = VsockHeader::from_bytes(&hdr_bytes.try_into().unwrap()); + assert_eq!(hdr.op, packet::VSOCK_OP_RW); + assert_eq!(hdr.len, 5); + + // Check payload follows header. + let payload = mem.read_bytes(BUF_BASE + VSOCK_HEADER_SIZE as u64, 5); + assert_eq!(payload, b"hello"); + } + + #[test] + fn test_inject_rx_no_available_buffers() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut rx_queue = setup_queue(128); + // Don't push any available buffers. + + let resp = VsockHeader::new_response(2, 2695, 3, 5000, 65536, 0); + dev.rx_pending.push((resp, Vec::new())); + + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(!injected); + + // Packet should still be pending. + assert_eq!(dev.rx_pending.len(), 1); + } + + // --- Poll default --- + + #[test] + fn test_poll_no_streams_no_pending() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut queues = vec![ + setup_queue(128), // RX + setup_queue(128), // TX + setup_queue(128), // Event + ]; + + let raised = dev.poll(&mut queues, &mem); + assert!(!raised); + } + + // --- Connection lifecycle through TX + RX --- + + #[test] + fn test_connection_lifecycle() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + + // Manually create connection to test data flow without sockets. 
+ let req_hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + let mut conn = VsockConnection::new(VSOCK_CID_HOST, 2695, 3, 5000); + conn.handle_request(&req_hdr); + dev.connections.insert((5000, 2695), conn); + + // Enqueue some host->guest data. + dev.connections + .get_mut(&(5000, 2695)) + .unwrap() + .enqueue_tx(b"response data"); + + // Set up RX buffer. + let rx_buf = BUF_BASE + 0x2000; + write_descriptor(&mem, 0, rx_buf, 256, 2, 0); // WRITE + push_avail(&mem, 0, 0); + + let mut rx_queue = setup_queue(128); + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(injected); + + // Verify the injected RW packet. + let hdr_bytes = mem.read_bytes(rx_buf, VSOCK_HEADER_SIZE); + let hdr = VsockHeader::from_bytes(&hdr_bytes.try_into().unwrap()); + assert_eq!(hdr.op, packet::VSOCK_OP_RW); + assert_eq!(hdr.len, 13); + + let payload = mem.read_bytes(rx_buf + VSOCK_HEADER_SIZE as u64, 13); + assert_eq!(payload, b"response data"); + } + + // --- Multiple connections --- + + #[test] + fn test_multiple_connections() { + let mut dev = VirtioVsock::new(3); + + let req1 = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + let req2 = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5001, + dst_port: 2696, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + + let mut c1 = VsockConnection::new(VSOCK_CID_HOST, 2695, 3, 5000); + c1.handle_request(&req1); + let mut c2 = VsockConnection::new(VSOCK_CID_HOST, 2696, 3, 5001); + c2.handle_request(&req2); + + dev.connections.insert((5000, 2695), c1); + dev.connections.insert((5001, 2696), c2); + + assert_eq!(dev.connection_count(), 2); + } + + // --- Short descriptor chain --- + + #[test] + fn test_tx_short_header_skipped() { + 
let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + // Descriptor with only 10 bytes (< 44 byte header). + write_descriptor(&mem, 0, BUF_BASE, 10, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); // Processed (skipped) the entry. + assert!(dev.rx_pending.is_empty()); // No response generated. + } + + // --- Empty chain --- + + #[test] + fn test_tx_empty_chain_skipped() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + // Descriptor with 0 length. + write_descriptor(&mem, 0, BUF_BASE, 0, 0, 0); + push_avail(&mem, 0, 0); + + let processed = dev.process_tx(&mut tx_queue, &mem); + assert!(processed); + } + + // --- Credit update flow --- + + #[test] + fn test_credit_update_injected() { + let mut dev = VirtioVsock::new(3); + let mem = MockMem::new(0x10000); + + let req = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + let mut conn = VsockConnection::new(VSOCK_CID_HOST, 2695, 3, 5000); + conn.handle_request(&req); + conn.handle_credit_request(); + dev.connections.insert((5000, 2695), conn); + + // RX buffer. 
+ let rx_buf = BUF_BASE + 0x2000; + write_descriptor(&mem, 0, rx_buf, 256, 2, 0); + push_avail(&mem, 0, 0); + + let mut rx_queue = setup_queue(128); + let injected = dev.inject_rx(&mut rx_queue, &mem); + assert!(injected); + + let hdr_bytes = mem.read_bytes(rx_buf, VSOCK_HEADER_SIZE); + let hdr = VsockHeader::from_bytes(&hdr_bytes.try_into().unwrap()); + assert_eq!(hdr.op, packet::VSOCK_OP_CREDIT_UPDATE); + } + + // --- Listen and connect with Unix sockets --- + + #[test] + fn test_listen_creates_listener() { + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("listen-test.sock"); + dev.listen_on(2695, sock_path.to_str().unwrap()).unwrap(); + assert_eq!(dev.listeners.len(), 1); + } + + #[test] + fn test_listen_on_two_vsock_ports() { + let mut dev = VirtioVsock::new(3); + let (path1, _dir1) = temp_socket_path("listen1.sock"); + let (path2, _dir2) = temp_socket_path("listen2.sock"); + dev.listen_on(2695, path1.to_str().unwrap()).unwrap(); + dev.listen_on(2696, path2.to_str().unwrap()).unwrap(); + assert_eq!(dev.listeners.len(), 2); + assert!(dev.listeners.contains_key(&2695)); + assert!(dev.listeners.contains_key(&2696)); + } + + #[test] + fn test_listen_with_uds_connect() { + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("listen-connect.sock"); + let vsock_port = 2695u32; + dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + // Connect a UDS client before the guest sends REQUEST. 
+ let _client = UnixStream::connect(&sock_path).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: vsock_port, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + // Should have RESPONSE and a stream. + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, packet::VSOCK_OP_RESPONSE); + assert_eq!(dev.connection_count(), 1); + assert_eq!(dev.streams.len(), 1); + } + + // --- Poll with UDS data --- + + #[test] + fn test_poll_reads_uds_data() { + use std::io::Write as IoWrite; + + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("poll-data.sock"); + let vsock_port = 2695u32; + dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + // Connect UDS client. + let mut client = UnixStream::connect(&sock_path).unwrap(); + + // Establish vsock connection. + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: vsock_port, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + dev.process_tx(&mut tx_queue, &mem); + dev.rx_pending.clear(); + + // Send data from UDS client to be picked up by poll. + client.write_all(b"uds data").unwrap(); + client.flush().unwrap(); + + std::thread::sleep(std::time::Duration::from_millis(50)); + + // Poll should read UDS data and queue it. 
+ let mut queues = vec![ + setup_queue(128), // RX + setup_queue(128), // TX + setup_queue(128), // Event + ]; + + let rx_buf = BUF_BASE + 0x4000; + let rx_desc = 0x8000u64; + let rx_avail = 0x8800u64; + let rx_used = 0x9000u64; + queues[0].set_desc_table(rx_desc); + queues[0].set_avail_ring(rx_avail); + queues[0].set_used_ring(rx_used); + + mem.write_u64_at(rx_desc, rx_buf); + mem.write_u32_at(rx_desc + 8, 256); + mem.write_u16_at(rx_desc + 12, 2); // WRITE + mem.write_u16_at(rx_desc + 14, 0); + mem.write_u16_at(rx_avail + 4, 0); + mem.write_u16_at(rx_avail + 2, 1); + + let raised = dev.poll(&mut queues, &mem); + assert!(raised); + + let hdr_bytes = mem.read_bytes(rx_buf, VSOCK_HEADER_SIZE); + let rx_hdr = VsockHeader::from_bytes(&hdr_bytes.try_into().unwrap()); + assert_eq!(rx_hdr.op, packet::VSOCK_OP_RW); + assert_eq!(rx_hdr.len, 8); + + let payload = mem.read_bytes(rx_buf + VSOCK_HEADER_SIZE as u64, 8); + assert_eq!(payload, b"uds data"); + } + + // --- Guest-initiated outbound connection --- + + #[test] + fn test_connect_to_registers_target() { + let mut dev = VirtioVsock::new(3); + dev.connect_to(2696, "/tmp/nonexistent.sock".to_string()); + assert_eq!(dev.connect_targets.len(), 1); + assert!(dev.connect_targets.contains_key(&2696)); + } + + #[test] + fn test_connect_to_outbound_success() { + // Set up a host-side Unix listener to receive the outbound connection. 
+ let (host_sock, _dir) = temp_socket_path("host-outbound.sock"); + let host_listener = UnixListener::bind(&host_sock).unwrap(); + + let mut dev = VirtioVsock::new(3); + dev.connect_to(2696, host_sock.to_str().unwrap().to_string()); + + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2696, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, packet::VSOCK_OP_RESPONSE); + assert_eq!(dev.connection_count(), 1); + assert_eq!(dev.streams.len(), 1); + + // Host listener should have received the connection. + host_listener.set_nonblocking(true).unwrap(); + let accepted = host_listener.accept(); + assert!(accepted.is_ok(), "Host should have received UDS connection"); + } + + #[test] + fn test_connect_to_unreachable_sends_rst() { + let mut dev = VirtioVsock::new(3); + // Nonexistent path — connection will fail. 
+ dev.connect_to(2696, "/tmp/nonexistent-vsock-test-path.sock".to_string()); + + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2696, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, VSOCK_OP_RST); + assert_eq!(dev.connection_count(), 0); + } + + #[test] + fn test_connect_to_preferred_over_listener() { + let (host_sock, _dir) = temp_socket_path("preferred.sock"); + let _host_listener = UnixListener::bind(&host_sock).unwrap(); + + let (listen_sock, _dir2) = temp_socket_path("listen-fallback.sock"); + + let mut dev = VirtioVsock::new(3); + dev.connect_to(2696, host_sock.to_str().unwrap().to_string()); + dev.listen_on(2696, listen_sock.to_str().unwrap()).unwrap(); + + let mem = MockMem::new(0x10000); + let mut tx_queue = setup_queue(128); + + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2696, + len: 0, + type_: 1, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + mem.write_bytes(BUF_BASE, &hdr.to_bytes()); + write_descriptor(&mem, 0, BUF_BASE, VSOCK_HEADER_SIZE as u32, 0, 0); + push_avail(&mem, 0, 0); + + dev.process_tx(&mut tx_queue, &mem); + + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, packet::VSOCK_OP_RESPONSE); + assert_eq!(dev.connection_count(), 1); + assert_eq!(dev.streams.len(), 1); + } + + // --- Host-initiated connections (poll_listeners) --- + + #[test] + fn test_poll_listeners_accepts_and_sends_request() { + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("poll-accept.sock"); + let vsock_port = 2695u32; + 
dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + // Host UDS client connects BEFORE any guest action. + let _client = UnixStream::connect(&sock_path).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + dev.poll_listeners(); + + assert_eq!(dev.rx_pending.len(), 1); + assert_eq!(dev.rx_pending[0].0.op, VSOCK_OP_REQUEST); + assert_eq!(dev.rx_pending[0].0.src_cid, VSOCK_CID_HOST); + assert_eq!(dev.rx_pending[0].0.dst_cid, 3); + assert_eq!(dev.rx_pending[0].0.dst_port, vsock_port); + assert!(dev.rx_pending[0].0.src_port >= EPHEMERAL_PORT_START); + assert_eq!(dev.connection_count(), 1); + assert_eq!(dev.streams.len(), 1); + } + + #[test] + fn test_poll_listeners_no_pending_is_noop() { + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("poll-noop.sock"); + dev.listen_on(2695, sock_path.to_str().unwrap()).unwrap(); + + // No client connected. + dev.poll_listeners(); + + assert!(dev.rx_pending.is_empty()); + assert_eq!(dev.connection_count(), 0); + } + + #[test] + fn test_host_initiated_full_lifecycle() { + use std::io::Write as IoWrite; + + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("lifecycle.sock"); + let vsock_port = 2695u32; + dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + // Step 1: Host client connects. + let mut client = UnixStream::connect(&sock_path).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + // Step 2: VMM accepts and sends REQUEST to guest. + dev.poll_listeners(); + assert_eq!(dev.rx_pending.len(), 1); + let req = &dev.rx_pending[0].0; + assert_eq!(req.op, VSOCK_OP_REQUEST); + let host_ephemeral = req.src_port; + let key = (vsock_port, host_ephemeral); + dev.rx_pending.clear(); + + // Step 3: Guest sends RESPONSE. 
+ let resp = VsockHeader { + src_cid: 3, + dst_cid: VSOCK_CID_HOST, + src_port: vsock_port, + dst_port: host_ephemeral, + len: 0, + type_: 1, + op: packet::VSOCK_OP_RESPONSE, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + dev.handle_guest_packet(&resp, &[]); + + assert_eq!( + dev.connections.get(&key).unwrap().state(), + ConnState::Connected + ); + + // Step 4: Host sends data → forwarded to guest via vsock. + client.write_all(b"hello from host").unwrap(); + client.flush().unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + dev.poll_streams(); + let conn = dev.connections.get(&key).unwrap(); + assert!(conn.tx_buf_len() > 0); + } + + #[test] + fn test_host_initiated_skips_data_during_handshake() { + use std::io::Write as IoWrite; + + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("handshake.sock"); + let vsock_port = 2695u32; + dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + let mut client = UnixStream::connect(&sock_path).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + dev.poll_listeners(); + let host_ephemeral = dev.rx_pending[0].0.src_port; + let key = (vsock_port, host_ephemeral); + dev.rx_pending.clear(); + + assert_eq!( + dev.connections.get(&key).unwrap().state(), + ConnState::Connecting + ); + + // Host sends data while still Connecting. + client.write_all(b"premature data").unwrap(); + client.flush().unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + // poll_streams should SKIP this stream (not Connected yet). + dev.poll_streams(); + assert_eq!(dev.connections.get(&key).unwrap().tx_buf_len(), 0); + + // Complete the handshake. 
+ let resp = VsockHeader { + src_cid: 3, + dst_cid: VSOCK_CID_HOST, + src_port: vsock_port, + dst_port: host_ephemeral, + len: 0, + type_: 1, + op: packet::VSOCK_OP_RESPONSE, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + dev.handle_guest_packet(&resp, &[]); + assert_eq!( + dev.connections.get(&key).unwrap().state(), + ConnState::Connected + ); + + // NOW poll_streams reads the data. + dev.poll_streams(); + assert!(dev.connections.get(&key).unwrap().tx_buf_len() > 0); + } + + #[test] + fn test_ephemeral_port_allocation() { + let mut dev = VirtioVsock::new(3); + let p1 = dev.alloc_host_port(); + let p2 = dev.alloc_host_port(); + let p3 = dev.alloc_host_port(); + assert_eq!(p1, EPHEMERAL_PORT_START); + assert_eq!(p2, EPHEMERAL_PORT_START + 1); + assert_eq!(p3, EPHEMERAL_PORT_START + 2); + } + + #[test] + fn test_host_initiated_guest_data_to_host_uds() { + use std::io::Read as IoRead; + + let mut dev = VirtioVsock::new(3); + let (sock_path, _dir) = temp_socket_path("guest-data.sock"); + let vsock_port = 2695u32; + dev.listen_on(vsock_port, sock_path.to_str().unwrap()) + .unwrap(); + + let mut client = UnixStream::connect(&sock_path).unwrap(); + client.set_nonblocking(true).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(50)); + + dev.poll_listeners(); + let host_ephemeral = dev.rx_pending[0].0.src_port; + let key = (vsock_port, host_ephemeral); + dev.rx_pending.clear(); + + // Guest RESPONSE. + let resp = VsockHeader { + src_cid: 3, + dst_cid: VSOCK_CID_HOST, + src_port: vsock_port, + dst_port: host_ephemeral, + len: 0, + type_: 1, + op: packet::VSOCK_OP_RESPONSE, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + dev.handle_guest_packet(&resp, &[]); + + // Guest sends data (RW) → should be forwarded to host Unix stream. 
+ let rw_hdr = VsockHeader { + src_cid: 3, + dst_cid: VSOCK_CID_HOST, + src_port: vsock_port, + dst_port: host_ephemeral, + len: 11, + type_: 1, + op: packet::VSOCK_OP_RW, + flags: 0, + buf_alloc: 32768, + fwd_cnt: 0, + }; + dev.handle_guest_packet(&rw_hdr, b"hello guest"); + + // Read from UDS client. + std::thread::sleep(std::time::Duration::from_millis(50)); + let mut buf = [0u8; 128]; + let n = client.read(&mut buf).unwrap(); + assert_eq!(&buf[..n], b"hello guest"); + } +} diff --git a/src/vmm/src/windows/devices/virtio/vsock/packet.rs b/src/vmm/src/windows/devices/virtio/vsock/packet.rs new file mode 100644 index 000000000..4913612cd --- /dev/null +++ b/src/vmm/src/windows/devices/virtio/vsock/packet.rs @@ -0,0 +1,486 @@ +//! Virtio-vsock packet header (virtio spec v1.2 Section 5.10.6). +//! +//! The 44-byte header is prepended to every vsock packet in the +//! TX and RX virtqueues. It carries addressing, flow control credits, +//! and operation codes for the vsock connection protocol. + +use super::super::super::super::error::{Result, WkrunError}; +use super::super::queue::GuestMemoryAccessor; + +// --- CID constants --- + +/// Well-known CID for the host (hypervisor). +pub const VSOCK_CID_HOST: u64 = 2; + +// --- Vsock type --- + +/// Stream transport (SOCK_STREAM equivalent). +pub const VIRTIO_VSOCK_TYPE_STREAM: u16 = 1; + +// --- Vsock operations (spec 5.10.6.6) --- + +/// Invalid operation. +pub const VSOCK_OP_INVALID: u16 = 0; +/// Connection request (guest -> host). +pub const VSOCK_OP_REQUEST: u16 = 1; +/// Connection accepted (host -> guest). +pub const VSOCK_OP_RESPONSE: u16 = 2; +/// Connection reset / refused. +pub const VSOCK_OP_RST: u16 = 3; +/// Graceful shutdown. +pub const VSOCK_OP_SHUTDOWN: u16 = 4; +/// Data transfer. +pub const VSOCK_OP_RW: u16 = 5; +/// Credit update (no payload). +pub const VSOCK_OP_CREDIT_UPDATE: u16 = 6; +/// Credit request (ask peer to send credit update). 
+pub const VSOCK_OP_CREDIT_REQUEST: u16 = 7; + +// --- Shutdown flags --- + +/// Shutdown flag: no more data to send. +pub const VSOCK_SHUTDOWN_SEND: u32 = 1; +/// Shutdown flag: no more data to receive. +pub const VSOCK_SHUTDOWN_RECV: u32 = 2; + +/// Size of the vsock packet header in bytes. +pub const VSOCK_HEADER_SIZE: usize = 44; + +/// Virtio-vsock packet header (44 bytes, little-endian). +/// +/// Layout (spec 5.10.6): +/// offset 0: src_cid (u64) +/// offset 8: dst_cid (u64) +/// offset 16: src_port (u32) +/// offset 20: dst_port (u32) +/// offset 24: len (u32) - payload length +/// offset 28: type_ (u16) - VIRTIO_VSOCK_TYPE_STREAM +/// offset 30: op (u16) - operation code +/// offset 32: flags (u32) - operation-specific flags +/// offset 36: buf_alloc (u32) - credit: total buffer space +/// offset 40: fwd_cnt (u32) - credit: bytes consumed so far +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct VsockHeader { + pub src_cid: u64, + pub dst_cid: u64, + pub src_port: u32, + pub dst_port: u32, + pub len: u32, + pub type_: u16, + pub op: u16, + pub flags: u32, + pub buf_alloc: u32, + pub fwd_cnt: u32, +} + +impl VsockHeader { + /// Read a vsock header from guest memory at the given address. + pub fn read_from(mem: &dyn GuestMemoryAccessor, addr: u64) -> Result { + let mut buf = [0u8; VSOCK_HEADER_SIZE]; + mem.read_at(addr, &mut buf)?; + Ok(Self::from_bytes(&buf)) + } + + /// Write this vsock header to guest memory at the given address. + pub fn write_to(&self, mem: &dyn GuestMemoryAccessor, addr: u64) -> Result<()> { + let buf = self.to_bytes(); + mem.write_at(addr, &buf) + } + + /// Parse a vsock header from a 44-byte buffer. 
+ pub fn from_bytes(buf: &[u8; VSOCK_HEADER_SIZE]) -> Self { + VsockHeader { + src_cid: u64::from_le_bytes(buf[0..8].try_into().unwrap()), + dst_cid: u64::from_le_bytes(buf[8..16].try_into().unwrap()), + src_port: u32::from_le_bytes(buf[16..20].try_into().unwrap()), + dst_port: u32::from_le_bytes(buf[20..24].try_into().unwrap()), + len: u32::from_le_bytes(buf[24..28].try_into().unwrap()), + type_: u16::from_le_bytes(buf[28..30].try_into().unwrap()), + op: u16::from_le_bytes(buf[30..32].try_into().unwrap()), + flags: u32::from_le_bytes(buf[32..36].try_into().unwrap()), + buf_alloc: u32::from_le_bytes(buf[36..40].try_into().unwrap()), + fwd_cnt: u32::from_le_bytes(buf[40..44].try_into().unwrap()), + } + } + + /// Serialize this header to a 44-byte buffer. + pub fn to_bytes(&self) -> [u8; VSOCK_HEADER_SIZE] { + let mut buf = [0u8; VSOCK_HEADER_SIZE]; + buf[0..8].copy_from_slice(&self.src_cid.to_le_bytes()); + buf[8..16].copy_from_slice(&self.dst_cid.to_le_bytes()); + buf[16..20].copy_from_slice(&self.src_port.to_le_bytes()); + buf[20..24].copy_from_slice(&self.dst_port.to_le_bytes()); + buf[24..28].copy_from_slice(&self.len.to_le_bytes()); + buf[28..30].copy_from_slice(&self.type_.to_le_bytes()); + buf[30..32].copy_from_slice(&self.op.to_le_bytes()); + buf[32..36].copy_from_slice(&self.flags.to_le_bytes()); + buf[36..40].copy_from_slice(&self.buf_alloc.to_le_bytes()); + buf[40..44].copy_from_slice(&self.fwd_cnt.to_le_bytes()); + buf + } + + /// Create a REQUEST header (host -> guest) for a host-initiated connection. + pub fn new_request( + src_cid: u64, + src_port: u32, + dst_cid: u64, + dst_port: u32, + buf_alloc: u32, + fwd_cnt: u32, + ) -> Self { + VsockHeader { + src_cid, + dst_cid, + src_port, + dst_port, + len: 0, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc, + fwd_cnt, + } + } + + /// Create a RESPONSE header (host -> guest) for a given REQUEST. 
+ pub fn new_response( + src_cid: u64, + src_port: u32, + dst_cid: u64, + dst_port: u32, + buf_alloc: u32, + fwd_cnt: u32, + ) -> Self { + VsockHeader { + src_cid, + dst_cid, + src_port, + dst_port, + len: 0, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_RESPONSE, + flags: 0, + buf_alloc, + fwd_cnt, + } + } + + /// Create an RW (data) header. + pub fn new_rw( + src_cid: u64, + src_port: u32, + dst_cid: u64, + dst_port: u32, + payload_len: u32, + buf_alloc: u32, + fwd_cnt: u32, + ) -> Self { + VsockHeader { + src_cid, + dst_cid, + src_port, + dst_port, + len: payload_len, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_RW, + flags: 0, + buf_alloc, + fwd_cnt, + } + } + + /// Create a RST header. + pub fn new_rst(src_cid: u64, src_port: u32, dst_cid: u64, dst_port: u32) -> Self { + VsockHeader { + src_cid, + dst_cid, + src_port, + dst_port, + len: 0, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_RST, + flags: 0, + buf_alloc: 0, + fwd_cnt: 0, + } + } + + /// Create a SHUTDOWN header. + pub fn new_shutdown( + src_cid: u64, + src_port: u32, + dst_cid: u64, + dst_port: u32, + flags: u32, + ) -> Self { + VsockHeader { + src_cid, + dst_cid, + src_port, + dst_port, + len: 0, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_SHUTDOWN, + flags, + buf_alloc: 0, + fwd_cnt: 0, + } + } + + /// Create a CREDIT_UPDATE header. + pub fn new_credit_update( + src_cid: u64, + src_port: u32, + dst_cid: u64, + dst_port: u32, + buf_alloc: u32, + fwd_cnt: u32, + ) -> Self { + VsockHeader { + src_cid, + dst_cid, + src_port, + dst_port, + len: 0, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_CREDIT_UPDATE, + flags: 0, + buf_alloc, + fwd_cnt, + } + } + + /// Validate that this header has a known operation and stream type. 
+ pub fn validate(&self) -> Result<()> { + if self.type_ != VIRTIO_VSOCK_TYPE_STREAM { + return Err(WkrunError::Device(format!( + "unsupported vsock type: {} (expected stream=1)", + self.type_ + ))); + } + if self.op > VSOCK_OP_CREDIT_REQUEST { + return Err(WkrunError::Device(format!( + "unknown vsock operation: {}", + self.op + ))); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::super::super::super::error::Result; + use super::*; + use std::cell::RefCell; + + struct MockMem(RefCell>); + impl MockMem { + fn new(size: usize) -> Self { + MockMem(RefCell::new(vec![0u8; size])) + } + } + impl GuestMemoryAccessor for MockMem { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + let a = addr as usize; + let data = self.0.borrow(); + buf.copy_from_slice(&data[a..a + buf.len()]); + Ok(()) + } + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + let a = addr as usize; + let mut mem = self.0.borrow_mut(); + mem[a..a + data.len()].copy_from_slice(data); + Ok(()) + } + } + + #[test] + fn test_header_size() { + assert_eq!(VSOCK_HEADER_SIZE, 44); + } + + #[test] + fn test_roundtrip_bytes() { + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 1234, + dst_port: 2695, + len: 100, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_RW, + flags: 0, + buf_alloc: 65536, + fwd_cnt: 512, + }; + let bytes = hdr.to_bytes(); + assert_eq!(bytes.len(), VSOCK_HEADER_SIZE); + let parsed = VsockHeader::from_bytes(&bytes); + assert_eq!(parsed, hdr); + } + + #[test] + fn test_field_offsets() { + let hdr = VsockHeader { + src_cid: 0x0102_0304_0506_0708, + dst_cid: 0x090A_0B0C_0D0E_0F10, + src_port: 0x11121314, + dst_port: 0x15161718, + len: 0x191A1B1C, + type_: 0x1D1E, + op: 0x1F20, + flags: 0x21222324, + buf_alloc: 0x25262728, + fwd_cnt: 0x292A2B2C, + }; + let buf = hdr.to_bytes(); + + // Verify each field starts at the correct offset. 
+ assert_eq!( + u64::from_le_bytes(buf[0..8].try_into().unwrap()), + hdr.src_cid + ); + assert_eq!( + u64::from_le_bytes(buf[8..16].try_into().unwrap()), + hdr.dst_cid + ); + assert_eq!( + u32::from_le_bytes(buf[16..20].try_into().unwrap()), + hdr.src_port + ); + assert_eq!( + u32::from_le_bytes(buf[20..24].try_into().unwrap()), + hdr.dst_port + ); + assert_eq!(u32::from_le_bytes(buf[24..28].try_into().unwrap()), hdr.len); + assert_eq!( + u16::from_le_bytes(buf[28..30].try_into().unwrap()), + hdr.type_ + ); + assert_eq!(u16::from_le_bytes(buf[30..32].try_into().unwrap()), hdr.op); + assert_eq!( + u32::from_le_bytes(buf[32..36].try_into().unwrap()), + hdr.flags + ); + assert_eq!( + u32::from_le_bytes(buf[36..40].try_into().unwrap()), + hdr.buf_alloc + ); + assert_eq!( + u32::from_le_bytes(buf[40..44].try_into().unwrap()), + hdr.fwd_cnt + ); + } + + #[test] + fn test_read_write_guest_memory() { + let mem = MockMem::new(256); + let hdr = VsockHeader { + src_cid: 3, + dst_cid: 2, + src_port: 5000, + dst_port: 2695, + len: 0, + type_: VIRTIO_VSOCK_TYPE_STREAM, + op: VSOCK_OP_REQUEST, + flags: 0, + buf_alloc: 4096, + fwd_cnt: 0, + }; + hdr.write_to(&mem, 0).unwrap(); + let read_back = VsockHeader::read_from(&mem, 0).unwrap(); + assert_eq!(read_back, hdr); + } + + #[test] + fn test_new_request() { + let hdr = VsockHeader::new_request(2, 49152, 3, 2695, 65536, 0); + assert_eq!(hdr.src_cid, 2); + assert_eq!(hdr.dst_cid, 3); + assert_eq!(hdr.src_port, 49152); + assert_eq!(hdr.dst_port, 2695); + assert_eq!(hdr.len, 0); + assert_eq!(hdr.type_, VIRTIO_VSOCK_TYPE_STREAM); + assert_eq!(hdr.op, VSOCK_OP_REQUEST); + assert_eq!(hdr.buf_alloc, 65536); + assert_eq!(hdr.fwd_cnt, 0); + } + + #[test] + fn test_new_response() { + let hdr = VsockHeader::new_response(2, 2695, 3, 5000, 65536, 0); + assert_eq!(hdr.src_cid, 2); + assert_eq!(hdr.dst_cid, 3); + assert_eq!(hdr.src_port, 2695); + assert_eq!(hdr.dst_port, 5000); + assert_eq!(hdr.len, 0); + assert_eq!(hdr.type_, 
VIRTIO_VSOCK_TYPE_STREAM); + assert_eq!(hdr.op, VSOCK_OP_RESPONSE); + assert_eq!(hdr.buf_alloc, 65536); + assert_eq!(hdr.fwd_cnt, 0); + } + + #[test] + fn test_new_rw() { + let hdr = VsockHeader::new_rw(2, 2695, 3, 5000, 128, 65536, 64); + assert_eq!(hdr.op, VSOCK_OP_RW); + assert_eq!(hdr.len, 128); + assert_eq!(hdr.buf_alloc, 65536); + assert_eq!(hdr.fwd_cnt, 64); + } + + #[test] + fn test_new_rst() { + let hdr = VsockHeader::new_rst(2, 2695, 3, 5000); + assert_eq!(hdr.op, VSOCK_OP_RST); + assert_eq!(hdr.len, 0); + assert_eq!(hdr.buf_alloc, 0); + assert_eq!(hdr.fwd_cnt, 0); + } + + #[test] + fn test_new_shutdown() { + let hdr = + VsockHeader::new_shutdown(3, 5000, 2, 2695, VSOCK_SHUTDOWN_SEND | VSOCK_SHUTDOWN_RECV); + assert_eq!(hdr.op, VSOCK_OP_SHUTDOWN); + assert_eq!(hdr.flags, 3); + } + + #[test] + fn test_new_credit_update() { + let hdr = VsockHeader::new_credit_update(2, 2695, 3, 5000, 32768, 1024); + assert_eq!(hdr.op, VSOCK_OP_CREDIT_UPDATE); + assert_eq!(hdr.buf_alloc, 32768); + assert_eq!(hdr.fwd_cnt, 1024); + } + + #[test] + fn test_validate_valid() { + let hdr = VsockHeader::new_response(2, 2695, 3, 5000, 65536, 0); + assert!(hdr.validate().is_ok()); + } + + #[test] + fn test_validate_bad_type() { + let mut hdr = VsockHeader::new_response(2, 2695, 3, 5000, 65536, 0); + hdr.type_ = 99; + assert!(hdr.validate().is_err()); + } + + #[test] + fn test_validate_bad_op() { + let mut hdr = VsockHeader::new_response(2, 2695, 3, 5000, 65536, 0); + hdr.op = 99; + assert!(hdr.validate().is_err()); + } + + #[test] + fn test_zero_header() { + let buf = [0u8; VSOCK_HEADER_SIZE]; + let hdr = VsockHeader::from_bytes(&buf); + assert_eq!(hdr.src_cid, 0); + assert_eq!(hdr.dst_cid, 0); + assert_eq!(hdr.op, VSOCK_OP_INVALID); + } +} diff --git a/src/vmm/src/windows/error.rs b/src/vmm/src/windows/error.rs new file mode 100644 index 000000000..ac363f95e --- /dev/null +++ b/src/vmm/src/windows/error.rs @@ -0,0 +1,116 @@ +//! Error types for the Windows WHPX backend. 
+ +/// Result type for WHPX operations. +pub type Result = std::result::Result; + +/// Errors that can occur in the WHPX backend. +#[derive(Debug, thiserror::Error)] +pub enum WkrunError { + /// WHPX API call failed with an HRESULT. + #[error("WHPX API call failed: {function} returned 0x{hresult:08X}")] + WhpxApi { + function: &'static str, + hresult: u32, + }, + + /// WHPX/Hyper-V is not available on this system. + #[error("WHPX not available: {0}")] + WhpxUnavailable(String), + + /// Invalid VM context ID. + #[error("invalid context ID: {0}")] + InvalidContext(u32), + + /// Context ID already in use. + #[error("context ID {0} already exists")] + ContextExists(u32), + + /// VM configuration error. + #[error("VM configuration error: {0}")] + Config(String), + + /// Guest memory error. + #[error("guest memory error: {0}")] + Memory(String), + + /// vCPU error. + #[error("vCPU error: {0}")] + Vcpu(String), + + /// I/O error. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// Boot/kernel loading error. + #[error("boot error: {0}")] + Boot(String), + + /// Device emulation error. + #[error("device error: {0}")] + Device(String), + + /// VM is not in the expected state for this operation. + #[error("invalid VM state: expected {expected}, got {actual}")] + InvalidState { + expected: &'static str, + actual: String, + }, +} + +impl WkrunError { + /// Create a WHPX API error from a function name and HRESULT. + pub fn whpx(function: &'static str, hresult: u32) -> Self { + WkrunError::WhpxApi { function, hresult } + } +} + +/// Checks an HRESULT and returns an error if it indicates failure. +/// HRESULT values with the high bit set indicate failure. +#[cfg(target_os = "windows")] +pub fn check_hresult(function: &'static str, hr: i32) -> Result<()> { + if hr < 0 { + Err(WkrunError::whpx(function, hr as u32)) + } else { + Ok(()) + } +} + +/// Return code for the C API: 0 = success, negative = error. 
+#[repr(i32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CApiResult { + Success = 0, + InvalidContext = -1, + InvalidArgument = -2, + WhpxError = -3, + MemoryError = -4, + BootError = -5, + DeviceError = -6, + StateError = -7, + IoError = -8, + Unknown = -99, +} + +impl From<&WkrunError> for CApiResult { + fn from(err: &WkrunError) -> Self { + match err { + WkrunError::InvalidContext(_) => CApiResult::InvalidContext, + WkrunError::ContextExists(_) => CApiResult::InvalidContext, + WkrunError::Config(_) => CApiResult::InvalidArgument, + WkrunError::WhpxApi { .. } => CApiResult::WhpxError, + WkrunError::WhpxUnavailable(_) => CApiResult::WhpxError, + WkrunError::Memory(_) => CApiResult::MemoryError, + WkrunError::Boot(_) => CApiResult::BootError, + WkrunError::Device(_) => CApiResult::DeviceError, + WkrunError::InvalidState { .. } => CApiResult::StateError, + WkrunError::Vcpu(_) => CApiResult::DeviceError, + WkrunError::Io(_) => CApiResult::IoError, + } + } +} + +impl From<&WkrunError> for i32 { + fn from(err: &WkrunError) -> Self { + CApiResult::from(err) as i32 + } +} diff --git a/src/vmm/src/windows/insn.rs b/src/vmm/src/windows/insn.rs new file mode 100644 index 000000000..110ae1e19 --- /dev/null +++ b/src/vmm/src/windows/insn.rs @@ -0,0 +1,662 @@ +//! Minimal x86_64 instruction decoder for MMIO emulation. +//! +//! Decodes the instruction bytes provided by WHPX memory access exits +//! to extract write data, access size, and destination register for reads. +//! +//! Only handles the instruction patterns Linux generates for MMIO: +//! - MOV r/m, reg (0x88/0x89) — writeb/writel/writeq +//! - MOV reg, r/m (0x8A/0x8B) — readb/readl/readq +//! - MOV r/m, imm (0xC6/0xC7) — writeb/writel with immediate +//! - MOVZX reg, r/m (0x0F 0xB6/0xB7) — readb/readw with zero-extend + +use super::error::{Result, WkrunError}; +use super::types::StandardRegisters; + +/// Decoded MMIO instruction information. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct MmioInsn { + /// Number of bytes accessed (1, 2, 4, or 8). + pub access_size: u8, + /// For writes: the value being written. + pub data: u64, + /// Total instruction length in bytes. + pub len: u8, + /// Whether this is a write (true) or read (false). + pub is_write: bool, + /// For reads: which general-purpose register receives the value (0=RAX..15=R15). + pub gpr_index: Option, +} + +/// REX prefix bit fields. +struct Rex { + /// REX.W — 64-bit operand size. + w: bool, + /// REX.R — extends ModRM reg field. + r: bool, +} + +impl Rex { + fn none() -> Self { + Rex { w: false, r: false } + } + + fn from_byte(byte: u8) -> Self { + Rex { + w: byte & 0x08 != 0, + r: byte & 0x04 != 0, + } + } +} + +/// Read a general-purpose register value by index (0=RAX, 1=RCX, ..., 15=R15). +/// +/// The index matches x86_64 ModRM/SIB encoding: +/// 0=RAX, 1=RCX, 2=RDX, 3=RBX, 4=RSP, 5=RBP, 6=RSI, 7=RDI, +/// 8=R8, 9=R9, 10=R10, 11=R11, 12=R12, 13=R13, 14=R14, 15=R15 +pub fn read_gpr(regs: &StandardRegisters, index: u8) -> u64 { + match index { + 0 => regs.rax, + 1 => regs.rcx, + 2 => regs.rdx, + 3 => regs.rbx, + 4 => regs.rsp, + 5 => regs.rbp, + 6 => regs.rsi, + 7 => regs.rdi, + 8 => regs.r8, + 9 => regs.r9, + 10 => regs.r10, + 11 => regs.r11, + 12 => regs.r12, + 13 => regs.r13, + 14 => regs.r14, + 15 => regs.r15, + _ => 0, + } +} + +/// Calculate the length of the ModRM addressing mode (displacement bytes). +/// +/// For MMIO, the ModRM byte encodes a memory operand. We need to know +/// how many bytes the addressing mode consumes to find the instruction length. 
/// Number of displacement bytes implied by a ModRM byte alone.
///
/// The SIB-specific case (mod=00 with SIB base=101, which carries a disp32)
/// is not decided here; `addressing_mode_len` checks the SIB byte itself.
fn modrm_disp_len(modrm: u8, has_sib: bool) -> usize {
    match (modrm >> 6, modrm & 0x07) {
        // [RIP+disp32] or [disp32]
        (0b00, 0b101) => 4,
        // SIB form: any disp32 depends on the SIB base, resolved by the caller.
        (0b00, 0b100) if has_sib => 0,
        (0b00, _) => 0,
        // [reg+disp8]
        (0b01, _) => 1,
        // [reg+disp32]
        (0b10, _) => 4,
        // mod=11 is register-to-register (shouldn't happen for MMIO)
        _ => 0,
    }
}

/// Calculate total bytes consumed by ModRM + SIB + displacement.
///
/// `offset` is the index of the ModRM byte in `bytes`; returns 0 when that
/// index is past the end of the slice.
fn addressing_mode_len(bytes: &[u8], offset: usize) -> usize {
    let modrm = match bytes.get(offset) {
        Some(&byte) => byte,
        None => return 0,
    };
    let mod_field = modrm >> 6;

    // A SIB byte follows when rm=100 and the operand is a memory operand.
    let has_sib = (modrm & 0x07) == 0b100 && mod_field != 0b11;
    if !has_sib {
        return 1 + modrm_disp_len(modrm, false);
    }

    // mod=00 with SIB base=101 means "no base register, disp32 follows".
    let sib_base = bytes.get(offset + 1).map(|sib| sib & 0x07);
    if mod_field == 0b00 && sib_base == Some(0b101) {
        return 2 + 4;
    }

    2 + modrm_disp_len(modrm, true)
}
+pub fn decode_mmio_insn(bytes: &[u8], regs: &StandardRegisters) -> Result { + if bytes.is_empty() { + return Err(WkrunError::Device("empty instruction bytes".into())); + } + + let mut pos = 0; + let mut rex = Rex::none(); + let mut has_operand_size_prefix = false; + + // Parse prefixes. + loop { + if pos >= bytes.len() { + return Err(WkrunError::Device("instruction too short".into())); + } + match bytes[pos] { + 0x66 => { + has_operand_size_prefix = true; + pos += 1; + } + 0x67 => { + // Address-size prefix — skip but don't change operand size. + pos += 1; + } + 0xF2 | 0xF3 => { + // REP/REPNE prefix — skip. + pos += 1; + } + b @ 0x40..=0x4F => { + rex = Rex::from_byte(b); + pos += 1; + break; // REX must be last prefix. + } + _ => break, + } + } + + if pos >= bytes.len() { + return Err(WkrunError::Device( + "instruction too short after prefixes".into(), + )); + } + + let opcode = bytes[pos]; + pos += 1; + + match opcode { + // MOV r/m8, reg8 (write, 8-bit) + 0x88 => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOV r/m8,r8: missing ModRM".into())); + } + let modrm = bytes[pos]; + let reg = ((modrm >> 3) & 0x07) | if rex.r { 8 } else { 0 }; + let addr_len = addressing_mode_len(bytes, pos); + let value = read_gpr(regs, reg) & 0xFF; + Ok(MmioInsn { + access_size: 1, + data: value, + len: (pos + addr_len) as u8, + is_write: true, + gpr_index: None, + }) + } + + // MOV r/m16/32/64, reg16/32/64 (write) + 0x89 => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOV r/m,r: missing ModRM".into())); + } + let modrm = bytes[pos]; + let reg = ((modrm >> 3) & 0x07) | if rex.r { 8 } else { 0 }; + let addr_len = addressing_mode_len(bytes, pos); + let access_size = if rex.w { + 8 + } else if has_operand_size_prefix { + 2 + } else { + 4 + }; + let mask = match access_size { + 2 => 0xFFFF, + 4 => 0xFFFF_FFFF, + 8 => u64::MAX, + _ => 0xFF, + }; + let value = read_gpr(regs, reg) & mask; + Ok(MmioInsn { + access_size, + data: value, + len: (pos + addr_len) 
as u8, + is_write: true, + gpr_index: None, + }) + } + + // MOV reg8, r/m8 (read, 8-bit) + 0x8A => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOV r8,r/m8: missing ModRM".into())); + } + let modrm = bytes[pos]; + let reg = ((modrm >> 3) & 0x07) | if rex.r { 8 } else { 0 }; + let addr_len = addressing_mode_len(bytes, pos); + Ok(MmioInsn { + access_size: 1, + data: 0, + len: (pos + addr_len) as u8, + is_write: false, + gpr_index: Some(reg), + }) + } + + // MOV reg16/32/64, r/m16/32/64 (read) + 0x8B => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOV r,r/m: missing ModRM".into())); + } + let modrm = bytes[pos]; + let reg = ((modrm >> 3) & 0x07) | if rex.r { 8 } else { 0 }; + let addr_len = addressing_mode_len(bytes, pos); + let access_size = if rex.w { + 8 + } else if has_operand_size_prefix { + 2 + } else { + 4 + }; + Ok(MmioInsn { + access_size, + data: 0, + len: (pos + addr_len) as u8, + is_write: false, + gpr_index: Some(reg), + }) + } + + // MOV r/m8, imm8 (write, 8-bit immediate) + 0xC6 => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOV r/m8,imm8: missing ModRM".into())); + } + let addr_len = addressing_mode_len(bytes, pos); + let imm_pos = pos + addr_len; + if imm_pos >= bytes.len() { + return Err(WkrunError::Device( + "MOV r/m8,imm8: missing immediate".into(), + )); + } + let value = bytes[imm_pos] as u64; + Ok(MmioInsn { + access_size: 1, + data: value, + len: (imm_pos + 1) as u8, + is_write: true, + gpr_index: None, + }) + } + + // MOV r/m16/32, imm16/32 (write, immediate) + 0xC7 => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOV r/m,imm: missing ModRM".into())); + } + let addr_len = addressing_mode_len(bytes, pos); + let imm_pos = pos + addr_len; + let (access_size, imm_len) = if has_operand_size_prefix { + (2u8, 2usize) + } else { + (4u8, 4usize) + }; + if imm_pos + imm_len > bytes.len() { + return Err(WkrunError::Device("MOV r/m,imm: missing immediate".into())); + } + let value = match 
imm_len { + 2 => u16::from_le_bytes([bytes[imm_pos], bytes[imm_pos + 1]]) as u64, + 4 => u32::from_le_bytes([ + bytes[imm_pos], + bytes[imm_pos + 1], + bytes[imm_pos + 2], + bytes[imm_pos + 3], + ]) as u64, + _ => unreachable!(), + }; + Ok(MmioInsn { + access_size, + data: value, + len: (imm_pos + imm_len) as u8, + is_write: true, + gpr_index: None, + }) + } + + // Two-byte opcodes (0x0F prefix). + 0x0F => { + if pos >= bytes.len() { + return Err(WkrunError::Device( + "0x0F: missing second opcode byte".into(), + )); + } + let opcode2 = bytes[pos]; + pos += 1; + + match opcode2 { + // MOVZX reg, r/m8 (read, 8-bit zero-extended to 32/64) + 0xB6 => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOVZX r,r/m8: missing ModRM".into())); + } + let modrm = bytes[pos]; + let reg = ((modrm >> 3) & 0x07) | if rex.r { 8 } else { 0 }; + let addr_len = addressing_mode_len(bytes, pos); + Ok(MmioInsn { + access_size: 1, + data: 0, + len: (pos + addr_len) as u8, + is_write: false, + gpr_index: Some(reg), + }) + } + + // MOVZX reg, r/m16 (read, 16-bit zero-extended to 32/64) + 0xB7 => { + if pos >= bytes.len() { + return Err(WkrunError::Device("MOVZX r,r/m16: missing ModRM".into())); + } + let modrm = bytes[pos]; + let reg = ((modrm >> 3) & 0x07) | if rex.r { 8 } else { 0 }; + let addr_len = addressing_mode_len(bytes, pos); + Ok(MmioInsn { + access_size: 2, + data: 0, + len: (pos + addr_len) as u8, + is_write: false, + gpr_index: Some(reg), + }) + } + + _ => Err(WkrunError::Device(format!( + "unrecognized 0x0F opcode: 0x{:02X} (bytes: {:02X?})", + opcode2, bytes + ))), + } + } + + _ => Err(WkrunError::Device(format!( + "unrecognized MMIO opcode: 0x{:02X} (bytes: {:02X?})", + opcode, bytes + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_regs() -> StandardRegisters { + StandardRegisters { + rax: 0xDEAD_BEEF_CAFE_BABE, + rcx: 0x1111_1111_1111_1111, + rdx: 0x2222_2222_2222_2222, + rbx: 0x3333_3333_3333_3333, + rsp: 0x4444_4444_4444_4444, + rbp: 
0x5555_5555_5555_5555, + rsi: 0x6666_6666_6666_6666, + rdi: 0x7777_7777_7777_7777, + r8: 0x8888_8888_8888_8888, + r9: 0x9999_9999_9999_9999, + r10: 0xAAAA_AAAA_AAAA_AAAA, + r11: 0xBBBB_BBBB_BBBB_BBBB, + r12: 0xCCCC_CCCC_CCCC_CCCC, + r13: 0xDDDD_DDDD_DDDD_DDDD, + r14: 0xEEEE_EEEE_EEEE_EEEE, + r15: 0xFFFF_FFFF_FFFF_FFFF, + rip: 0, + rflags: 0, + } + } + + // --- MOV r/m32, reg (0x89) — writel --- + + #[test] + fn test_mov_dword_ptr_eax() { + // mov dword [rdi], eax → 89 07 + // ModRM: mod=00, reg=000(eax), r/m=111(rdi) + let bytes = [0x89, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.data, regs.rax & 0xFFFF_FFFF); + assert_eq!(insn.len, 2); + assert_eq!(insn.gpr_index, None); + } + + #[test] + fn test_mov_dword_ptr_ecx_disp8() { + // mov dword [rdi+0x10], ecx → 89 4F 10 + // ModRM: mod=01, reg=001(ecx), r/m=111(rdi) + let bytes = [0x89, 0x4F, 0x10]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.data, regs.rcx & 0xFFFF_FFFF); + assert_eq!(insn.len, 3); + } + + // --- MOV r/m64, reg (REX.W 0x89) — writeq --- + + #[test] + fn test_mov_qword_ptr_rax() { + // mov qword [rdi], rax → 48 89 07 + // REX.W=1, ModRM: mod=00, reg=000(rax), r/m=111(rdi) + let bytes = [0x48, 0x89, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 8); + assert_eq!(insn.data, regs.rax); + assert_eq!(insn.len, 3); + } + + // --- MOV r/m8, reg8 (0x88) — writeb --- + + #[test] + fn test_mov_byte_ptr_al() { + // mov byte [rdi], al → 88 07 + let bytes = [0x88, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 1); + assert_eq!(insn.data, regs.rax & 0xFF); + assert_eq!(insn.len, 2); + } + + // --- 
MOV r/m16, reg16 (0x66 0x89) — writew --- + + #[test] + fn test_mov_word_ptr_ax() { + // mov word [rdi], ax → 66 89 07 + let bytes = [0x66, 0x89, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 2); + assert_eq!(insn.data, regs.rax & 0xFFFF); + assert_eq!(insn.len, 3); + } + + // --- MOV reg32, r/m32 (0x8B) — readl --- + + #[test] + fn test_mov_eax_dword_ptr() { + // mov eax, dword [rdi] → 8B 07 + let bytes = [0x8B, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(!insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.gpr_index, Some(0)); // RAX + assert_eq!(insn.len, 2); + } + + // --- MOV reg64, r/m64 (REX.W 0x8B) — readq --- + + #[test] + fn test_mov_rax_qword_ptr() { + // mov rax, qword [rdi] → 48 8B 07 + let bytes = [0x48, 0x8B, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(!insn.is_write); + assert_eq!(insn.access_size, 8); + assert_eq!(insn.gpr_index, Some(0)); // RAX + assert_eq!(insn.len, 3); + } + + // --- MOV r/m32, imm32 (0xC7) — writel with immediate --- + + #[test] + fn test_mov_dword_ptr_imm32() { + // mov dword [rdi], 0x12345678 → C7 07 78 56 34 12 + let bytes = [0xC7, 0x07, 0x78, 0x56, 0x34, 0x12]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.data, 0x12345678); + assert_eq!(insn.len, 6); + } + + // --- MOV r/m8, imm8 (0xC6) — writeb with immediate --- + + #[test] + fn test_mov_byte_ptr_imm8() { + // mov byte [rdi], 0xAB → C6 07 AB + let bytes = [0xC6, 0x07, 0xAB]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 1); + assert_eq!(insn.data, 0xAB); + assert_eq!(insn.len, 3); + } + + // --- MOVZX reg, r/m8 (0x0F 0xB6) — readb --- + + #[test] + fn 
test_movzx_eax_byte_ptr() { + // movzx eax, byte [rdi] → 0F B6 07 + let bytes = [0x0F, 0xB6, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(!insn.is_write); + assert_eq!(insn.access_size, 1); + assert_eq!(insn.gpr_index, Some(0)); // EAX + assert_eq!(insn.len, 3); + } + + // --- MOVZX reg, r/m16 (0x0F 0xB7) — readw --- + + #[test] + fn test_movzx_eax_word_ptr() { + // movzx eax, word [rdi] → 0F B7 07 + let bytes = [0x0F, 0xB7, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(!insn.is_write); + assert_eq!(insn.access_size, 2); + assert_eq!(insn.gpr_index, Some(0)); // EAX + assert_eq!(insn.len, 3); + } + + // --- REX.R extended registers --- + + #[test] + fn test_mov_dword_ptr_r8d() { + // mov dword [rdi], r8d → 44 89 07 + // REX.R=1, reg=000 → reg=8 (R8) + let bytes = [0x44, 0x89, 0x07]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.data, regs.r8 & 0xFFFF_FFFF); + assert_eq!(insn.len, 3); + } + + #[test] + fn test_mov_r10d_dword_ptr() { + // mov r10d, dword [rdi] → 44 8B 17 + // REX.R=1, reg=010 → reg=10 (R10) + let bytes = [0x44, 0x8B, 0x17]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(!insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.gpr_index, Some(10)); // R10 + assert_eq!(insn.len, 3); + } + + // --- Error cases --- + + #[test] + fn test_empty_bytes_error() { + let regs = make_regs(); + assert!(decode_mmio_insn(&[], ®s).is_err()); + } + + #[test] + fn test_unrecognized_opcode_error() { + let regs = make_regs(); + let bytes = [0xFF, 0x07]; // Not a MOV + assert!(decode_mmio_insn(&bytes, ®s).is_err()); + } + + // --- disp32 addressing --- + + #[test] + fn test_mov_dword_ptr_disp32() { + // mov dword [rdi+0x100], eax → 89 87 00 01 00 00 + // ModRM: mod=10, reg=000(eax), r/m=111(rdi) → disp32 + let 
bytes = [0x89, 0x87, 0x00, 0x01, 0x00, 0x00]; + let regs = make_regs(); + let insn = decode_mmio_insn(&bytes, ®s).unwrap(); + assert!(insn.is_write); + assert_eq!(insn.access_size, 4); + assert_eq!(insn.data, regs.rax & 0xFFFF_FFFF); + assert_eq!(insn.len, 6); + } + + // --- read_gpr coverage --- + + #[test] + fn test_read_gpr_all_registers() { + let regs = make_regs(); + assert_eq!(read_gpr(®s, 0), regs.rax); + assert_eq!(read_gpr(®s, 1), regs.rcx); + assert_eq!(read_gpr(®s, 2), regs.rdx); + assert_eq!(read_gpr(®s, 3), regs.rbx); + assert_eq!(read_gpr(®s, 4), regs.rsp); + assert_eq!(read_gpr(®s, 5), regs.rbp); + assert_eq!(read_gpr(®s, 6), regs.rsi); + assert_eq!(read_gpr(®s, 7), regs.rdi); + assert_eq!(read_gpr(®s, 8), regs.r8); + assert_eq!(read_gpr(®s, 9), regs.r9); + assert_eq!(read_gpr(®s, 10), regs.r10); + assert_eq!(read_gpr(®s, 11), regs.r11); + assert_eq!(read_gpr(®s, 12), regs.r12); + assert_eq!(read_gpr(®s, 13), regs.r13); + assert_eq!(read_gpr(®s, 14), regs.r14); + assert_eq!(read_gpr(®s, 15), regs.r15); + assert_eq!(read_gpr(®s, 16), 0); // Out of range + } +} diff --git a/src/vmm/src/windows/memory.rs b/src/vmm/src/windows/memory.rs new file mode 100644 index 000000000..3f682aced --- /dev/null +++ b/src/vmm/src/windows/memory.rs @@ -0,0 +1,391 @@ +//! Guest memory management for WHPX VMs. +//! +//! Handles allocation and mapping of guest physical memory. +//! On Windows, we use VirtualAlloc for host-side memory allocation +//! since the rust-vmm vm-memory crate doesn't support Windows. +//! +//! Memory layout constants are available on all platforms for cross-platform +//! testing of boot setup logic. + +// Guest physical memory layout constants for x86_64 Linux boot. +// These match the conventional Linux boot protocol addresses. + +/// Start of the zero page (boot_params structure). +pub const ZERO_PAGE_START: u64 = 0x7000; + +/// Start of the PML4 page table. +pub const PML4_START: u64 = 0x9000; + +/// Start of the PDPT page table. 
+pub const PDPT_START: u64 = 0xA000; + +/// Start of the PD page tables (4 entries for identity-mapping 4GB). +pub const PD_START: u64 = 0xB000; + +/// Kernel command line address. +pub const CMDLINE_START: u64 = 0x20000; + +/// Maximum kernel command line length. +pub const CMDLINE_MAX_SIZE: u64 = 0x10000; + +/// Kernel load address (1MB — standard bzImage load address). +pub const KERNEL_START: u64 = 0x100000; + +/// Offset of the 64-bit entry point (`startup_64`) from KERNEL_START. +pub const KERNEL_64BIT_ENTRY_OFFSET: u64 = 0x200; + +/// ACPI tables region. +pub const ACPI_START: u64 = 0xE0000; + +/// Initial stack pointer (below 1MB, above page tables). +pub const BOOT_STACK_POINTER: u64 = 0x8FF0; + +/// Virtio-MMIO base address (above guest RAM, below 4GB identity map). +pub const VIRTIO_MMIO_BASE: u64 = 0xD000_0000; + +/// Size of the MMIO region reserved for virtio devices. +/// 2MB provides room for many devices and aligns with 2MB page table granularity. +pub const MMIO_REGION_SIZE: u64 = 0x20_0000; + +/// IOAPIC MMIO base address. +pub const IOAPIC_MMIO_BASE: u64 = 0xFEC0_0000; + +/// IOAPIC MMIO region size (4 KB). +pub const IOAPIC_MMIO_SIZE: u64 = 0x1000; + +/// LAPIC MMIO base address. +pub const LAPIC_MMIO_BASE: u64 = 0xFEE0_0000; + +/// LAPIC MMIO region size (4 KB). +pub const LAPIC_MMIO_SIZE: u64 = 0x1000; + +// Windows-specific guest memory allocation and mapping. +#[cfg(target_os = "windows")] +mod imp { + use std::ptr; + + use windows_sys::Win32::System::Hypervisor::WHV_MAP_GPA_RANGE_FLAGS; + use windows_sys::Win32::System::Memory::{ + VirtualAlloc, VirtualFree, MEM_COMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_READWRITE, + }; + + use super::super::error::{Result, WkrunError}; + use super::super::whpx::WhpxPartition; + + /// A contiguous region of guest physical memory. + pub struct GuestMemoryRegion { + /// Host virtual address of the allocated memory. + host_addr: *mut u8, + /// Guest physical address this region maps to. 
+ guest_addr: u64, + /// Size of the region in bytes. + size: u64, + } + + // SAFETY: The memory region is a simple allocation that can be sent between threads. + unsafe impl Send for GuestMemoryRegion {} + unsafe impl Sync for GuestMemoryRegion {} + + impl GuestMemoryRegion { + /// Allocate a new memory region using VirtualAlloc. + pub fn new(guest_addr: u64, size: u64) -> Result { + let host_addr = unsafe { + VirtualAlloc( + ptr::null(), + size as usize, + MEM_COMMIT | MEM_RESERVE, + PAGE_READWRITE, + ) + }; + + if host_addr.is_null() { + return Err(WkrunError::Memory(format!( + "VirtualAlloc failed for {} bytes at GPA 0x{:X}", + size, guest_addr + ))); + } + + Ok(GuestMemoryRegion { + host_addr: host_addr as *mut u8, + guest_addr, + size, + }) + } + + /// Get the host virtual address. + pub fn host_addr(&self) -> *mut u8 { + self.host_addr + } + + /// Get the guest physical address. + pub fn guest_addr(&self) -> u64 { + self.guest_addr + } + + /// Get the size of this region. + pub fn size(&self) -> u64 { + self.size + } + + /// Write data into guest memory at a guest physical address offset. + pub fn write_at(&self, offset: u64, data: &[u8]) -> Result<()> { + if offset + data.len() as u64 > self.size { + return Err(WkrunError::Memory(format!( + "write out of bounds: offset 0x{:X} + {} > region size 0x{:X}", + offset, + data.len(), + self.size + ))); + } + + // SAFETY: We verified the offset + len is within bounds. + unsafe { + let dst = self.host_addr.add(offset as usize); + ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len()); + } + Ok(()) + } + + /// Read data from guest memory at a guest physical address offset. + pub fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<()> { + if offset + buf.len() as u64 > self.size { + return Err(WkrunError::Memory(format!( + "read out of bounds: offset 0x{:X} + {} > region size 0x{:X}", + offset, + buf.len(), + self.size + ))); + } + + // SAFETY: We verified the offset + len is within bounds. 
+ unsafe { + let src = self.host_addr.add(offset as usize); + ptr::copy_nonoverlapping(src, buf.as_mut_ptr(), buf.len()); + } + Ok(()) + } + + /// Write a value at a specific offset. + pub fn write_obj(&self, offset: u64, val: &T) -> Result<()> { + let size = std::mem::size_of::() as u64; + if offset + size > self.size { + return Err(WkrunError::Memory(format!( + "write_obj out of bounds: offset 0x{:X} + {} > region size 0x{:X}", + offset, size, self.size + ))); + } + + // SAFETY: We verified bounds, and T is Copy (no drop needed). + unsafe { + let dst = self.host_addr.add(offset as usize) as *mut T; + ptr::write_unaligned(dst, *val); + } + Ok(()) + } + + /// Map this region into a WHPX partition's guest physical address space. + pub fn map_to_partition(&self, partition: &WhpxPartition) -> Result<()> { + // SAFETY: host_addr points to our VirtualAlloc'd memory which is valid + // for the lifetime of this GuestMemoryRegion. + unsafe { + partition.map_gpa_range( + self.host_addr, + self.guest_addr, + self.size, + // WHvMapGpaRangeFlagRead | WHvMapGpaRangeFlagWrite | WHvMapGpaRangeFlagExecute + 0x7 as WHV_MAP_GPA_RANGE_FLAGS, + ) + } + } + } + + impl Drop for GuestMemoryRegion { + fn drop(&mut self) { + if !self.host_addr.is_null() { + // SAFETY: We allocated this memory with VirtualAlloc. + unsafe { + VirtualFree(self.host_addr as *mut std::ffi::c_void, 0, MEM_RELEASE); + } + } + } + } + + /// Guest memory manager — holds all guest memory regions. + pub struct GuestMemory { + regions: Vec, + total_size: u64, + } + + impl GuestMemory { + /// Create guest memory, leaving holes for device MMIO regions. + /// + /// When guest RAM overlaps device MMIO addresses, the memory is split + /// into multiple regions with unmapped gaps so that WHPX generates MMIO + /// exits (instead of treating device accesses as RAM reads). + /// + /// Holes are created for: + /// - Virtio MMIO (0xD000_0000 .. 0xD020_0000) — virtio device registers + /// - APIC MMIO (0xFEC0_0000 .. 
0xFEE0_1000) — IOAPIC + LAPIC registers + pub fn new(size_mib: u32) -> Result { + let size = (size_mib as u64) * 1024 * 1024; + + if size > super::VIRTIO_MMIO_BASE { + let mmio_base = super::VIRTIO_MMIO_BASE; + let mmio_end = mmio_base + super::MMIO_REGION_SIZE; + let region1 = GuestMemoryRegion::new(0, mmio_base)?; + + // Check if RAM extends into the APIC MMIO region. + // IOAPIC at 0xFEC0_0000 and LAPIC at 0xFEE0_0000 must be + // unmapped so WHPX generates MMIO exits for APIC accesses. + let apic_start = super::IOAPIC_MMIO_BASE; + let apic_end = super::LAPIC_MMIO_BASE + super::LAPIC_MMIO_SIZE; + + if size > apic_start { + // RAM extends past APIC region — 3 regions with 2 holes. + // Region 1: 0 .. VIRTIO_MMIO_BASE + // (hole): VIRTIO MMIO + // Region 2: VIRTIO_MMIO_END .. IOAPIC_MMIO_BASE + // (hole): APIC MMIO (IOAPIC + LAPIC) + // Region 3: APIC_END .. ram_end + let region2 = GuestMemoryRegion::new(mmio_end, apic_start - mmio_end)?; + let mut regions = vec![region1, region2]; + + if size > apic_end { + let region3 = GuestMemoryRegion::new(apic_end, size - apic_end)?; + regions.push(region3); + } + + Ok(GuestMemory { + regions, + total_size: size, + }) + } else { + // RAM between VIRTIO and APIC — 2 regions with 1 hole. + let region2 = GuestMemoryRegion::new(mmio_end, size - mmio_end)?; + Ok(GuestMemory { + regions: vec![region1, region2], + total_size: size, + }) + } + } else { + // RAM fits below MMIO — single contiguous region. + let region = GuestMemoryRegion::new(0, size)?; + Ok(GuestMemory { + regions: vec![region], + total_size: size, + }) + } + } + + /// Map all guest memory regions into a WHPX partition. + pub fn map_to_partition(&self, partition: &WhpxPartition) -> Result<()> { + for region in &self.regions { + region.map_to_partition(partition)?; + } + Ok(()) + } + + /// Write data at a guest physical address. 
+ pub fn write_at_addr(&self, guest_addr: u64, data: &[u8]) -> Result<()> { + for region in &self.regions { + let region_end = region.guest_addr() + region.size(); + if guest_addr >= region.guest_addr() && guest_addr < region_end { + let offset = guest_addr - region.guest_addr(); + return region.write_at(offset, data); + } + } + Err(WkrunError::Memory(format!( + "no region contains GPA 0x{:X}", + guest_addr + ))) + } + + /// Read data from a guest physical address. + pub fn read_at_addr(&self, guest_addr: u64, buf: &mut [u8]) -> Result<()> { + for region in &self.regions { + let region_end = region.guest_addr() + region.size(); + if guest_addr >= region.guest_addr() && guest_addr < region_end { + let offset = guest_addr - region.guest_addr(); + return region.read_at(offset, buf); + } + } + Err(WkrunError::Memory(format!( + "no region contains GPA 0x{:X}", + guest_addr + ))) + } + + /// Write a typed value at a guest physical address. + pub fn write_obj_at_addr(&self, guest_addr: u64, val: &T) -> Result<()> { + for region in &self.regions { + let region_end = region.guest_addr() + region.size(); + if guest_addr >= region.guest_addr() && guest_addr < region_end { + let offset = guest_addr - region.guest_addr(); + return region.write_obj(offset, val); + } + } + Err(WkrunError::Memory(format!( + "no region contains GPA 0x{:X}", + guest_addr + ))) + } + + /// Get total guest memory size in bytes. + pub fn total_size(&self) -> u64 { + self.total_size + } + } +} + +#[cfg(target_os = "windows")] +pub use imp::*; + +#[cfg(test)] +mod tests { + use super::*; + + // Compile-time assertions for memory layout ordering. 
+ const _: () = { + assert!(ZERO_PAGE_START < PML4_START); + assert!(PML4_START < PDPT_START); + assert!(PDPT_START < PD_START); + assert!(PD_START < CMDLINE_START); + assert!(CMDLINE_START < KERNEL_START); + assert!(ZERO_PAGE_START < BOOT_STACK_POINTER); + assert!(BOOT_STACK_POINTER < PML4_START); + }; + + #[test] + fn test_kernel_start_at_1mb() { + assert_eq!(KERNEL_START, 0x100000); + } + + #[test] + fn test_memory_layout_no_overlap() { + let regions = [ + ("zero_page", ZERO_PAGE_START, ZERO_PAGE_START + 0x1000), + ("pml4", PML4_START, PML4_START + 0x1000), + ("pdpt", PDPT_START, PDPT_START + 0x1000), + ("pd", PD_START, PD_START + 0x4000), + ("cmdline", CMDLINE_START, CMDLINE_START + CMDLINE_MAX_SIZE), + ("kernel", KERNEL_START, KERNEL_START + 0x1000), + ]; + + for i in 0..regions.len() { + for j in (i + 1)..regions.len() { + let (name_a, start_a, end_a) = regions[i]; + let (name_b, start_b, end_b) = regions[j]; + assert!( + end_a <= start_b || end_b <= start_a, + "regions {} and {} overlap: [{:#X}..{:#X}) vs [{:#X}..{:#X})", + name_a, + name_b, + start_a, + end_a, + start_b, + end_b + ); + } + } + } +} diff --git a/src/vmm/src/windows/mod.rs b/src/vmm/src/windows/mod.rs new file mode 100644 index 000000000..e8ef0abe3 --- /dev/null +++ b/src/vmm/src/windows/mod.rs @@ -0,0 +1,16 @@ +//! Virtual Machine Manager — hypervisor abstraction and WHPX backend. + +pub mod types; + +#[cfg(target_os = "windows")] +pub mod whpx; + +pub mod boot; +pub mod cmdline; +pub mod context; +pub mod devices; +pub mod error; +pub mod insn; +pub mod memory; +pub mod runner; +pub mod vcpu; diff --git a/src/vmm/src/windows/runner.rs b/src/vmm/src/windows/runner.rs new file mode 100644 index 000000000..40d3da972 --- /dev/null +++ b/src/vmm/src/windows/runner.rs @@ -0,0 +1,2296 @@ +//! VmRunner — full VM boot orchestration. +//! +//! Takes a configured VmContext, creates the WHPX partition and devices, +//! loads the kernel, and runs the vCPU loop until exit. +//! +//! 
Supports two modes: +//! - **Blocking**: `run()` — runs vCPU loop on the calling thread (used by `wkrun_start_enter`) +//! - **Async**: `start()` / `wait()` / `stop()` — spawns a background VM thread (used by BoxLite's Tokio runtime) + +#[cfg(target_os = "windows")] +mod imp { + use std::collections::HashMap; + use std::io::Write; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Condvar, Mutex}; + use std::time::{Duration, Instant}; + + use super::super::boot::loader::load_kernel_with_initrd; + use super::super::cmdline::build_kernel_cmdline; + use super::super::context::VmContext; + use super::super::devices::lapic::{IpiAction, LocalApic, SharedApicState}; + use super::super::devices::manager::{self as devices, DeviceManager}; + use super::super::devices::virtio::queue::GuestMemoryAccessor; + use super::super::error::{Result, WkrunError}; + use super::super::memory::{GuestMemory, LAPIC_MMIO_BASE, LAPIC_MMIO_SIZE}; + use super::super::types::VcpuExit; + use super::super::vcpu::VcpuRunConfig; + use super::super::whpx::{VcpuCanceller, WhpxPartition, WhpxVcpu}; + + /// Implement GuestMemoryAccessor directly on GuestMemory. + /// + /// This allows GuestMemory to be used via the GuestMemoryAccessor trait + /// in device handling code (virtio queues, block I/O, etc.). + impl GuestMemoryAccessor for GuestMemory { + fn read_at(&self, addr: u64, buf: &mut [u8]) -> Result<()> { + self.read_at_addr(addr, buf) + } + fn write_at(&self, addr: u64, data: &[u8]) -> Result<()> { + self.write_at_addr(addr, data) + } + } + + /// Maximum vCPU exits before giving up. + const MAX_EXITS: u64 = 500_000_000; + + /// Maximum consecutive HLT instructions before assuming shutdown. + /// + /// With ACPI tables, `poweroff` is detected instantly via PM1a_CNT. + /// MAX_HALTS is a safety fallback for non-ACPI shutdown paths. + /// At 1ms per tick, 50000 = ~50 second timeout. + /// + /// Must be high enough to tolerate normal guest idle periods (e.g. 
+ /// waiting for gRPC data after boot). The guest HLTs in its idle + /// loop whenever there are no interrupts; this is normal and does + /// NOT indicate the VM is stuck. + const MAX_HALTS: u64 = 50_000; + + /// Number of spin-yield iterations before sleeping on HLT. + /// ~50µs of yielding to catch imminent timer interrupts. + const HLT_SPIN_ITERS: u32 = 50; + + /// Short sleep duration (µs) after spin phase completes without interrupt. + const HLT_SLEEP_US: u64 = 200; + + /// Per-AP (Application Processor) startup state. + /// + /// Each AP thread waits on its condvar until the BSP delivers an + /// INIT-SIPI-SIPI sequence via the LAPIC ICR register. + struct ApStartupState { + /// Whether this AP has received SIPI and should start executing. + started: Mutex, + /// Condvar to wake the AP thread when SIPI arrives. + condvar: Condvar, + /// SIPI vector — the AP starts executing at `vector * 0x1000`. + sipi_vector: Mutex>, + /// Whether INIT has been received (prerequisite for SIPI). + init_received: AtomicBool, + } + + impl ApStartupState { + fn new() -> Self { + Self { + started: Mutex::new(false), + condvar: Condvar::new(), + sipi_vector: Mutex::new(None), + init_received: AtomicBool::new(false), + } + } + } + + /// Handle for a running VM, stored in `RUNNING_VMS`. + struct VmHandle { + thread: Option>>, + run_config: VcpuRunConfig, + canceller: Arc>>, + } + + /// Registry of running VMs. A ctx_id appears here after `start()` and is + /// removed by `wait()`. + static RUNNING_VMS: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| Mutex::new(HashMap::new())); + + /// Translate a guest virtual address (GVA) to guest physical address (GPA) + /// by walking the x86_64 4-level page table starting from CR3. 
+ #[allow(dead_code)] + fn translate_gva(guest_mem: &GuestMemory, cr3: u64, gva: u64) -> Option { + let pml4_base = cr3 & !0xFFF; + let pml4_idx = ((gva >> 39) & 0x1FF) as usize; + let pdpt_idx = ((gva >> 30) & 0x1FF) as usize; + let pd_idx = ((gva >> 21) & 0x1FF) as usize; + let pt_idx = ((gva >> 12) & 0x1FF) as usize; + let offset = gva & 0xFFF; + + // PML4 entry + let mut buf = [0u8; 8]; + guest_mem + .read_at_addr(pml4_base + (pml4_idx as u64) * 8, &mut buf) + .ok()?; + let pml4e = u64::from_le_bytes(buf); + if pml4e & 1 == 0 { + return None; + } // not present + + // PDPT entry + let pdpt_base = pml4e & 0x000F_FFFF_FFFF_F000; + guest_mem + .read_at_addr(pdpt_base + (pdpt_idx as u64) * 8, &mut buf) + .ok()?; + let pdpte = u64::from_le_bytes(buf); + if pdpte & 1 == 0 { + return None; + } + if pdpte & 0x80 != 0 { + // 1GB page + return Some((pdpte & 0x000F_FFFF_C000_0000) | (gva & 0x3FFF_FFFF)); + } + + // PD entry + let pd_base = pdpte & 0x000F_FFFF_FFFF_F000; + guest_mem + .read_at_addr(pd_base + (pd_idx as u64) * 8, &mut buf) + .ok()?; + let pde = u64::from_le_bytes(buf); + if pde & 1 == 0 { + return None; + } + if pde & 0x80 != 0 { + // 2MB page + return Some((pde & 0x000F_FFFF_FFE0_0000) | (gva & 0x1F_FFFF)); + } + + // PT entry + let pt_base = pde & 0x000F_FFFF_FFFF_F000; + guest_mem + .read_at_addr(pt_base + (pt_idx as u64) * 8, &mut buf) + .ok()?; + let pte = u64::from_le_bytes(buf); + if pte & 1 == 0 { + return None; + } + Some((pte & 0x000F_FFFF_FFFF_F000) | offset) + } + + /// Core vCPU loop shared by `run()` and `start()`. + /// + /// Sets up the WHPX partition, loads the kernel, creates devices and vCPU, + /// then runs the vCPU loop. The `run_config` controls when the loop stops, + /// and the vCPU's canceller is stored in `canceller_slot` so that `stop()` + /// can wake the vCPU. + fn run_vcpu_loop( + ctx: VmContext, + run_config: VcpuRunConfig, + canceller_slot: Arc>>, + ) -> Result { + // Open a diagnostic log file for debugging boot failures. 
+ // Uses TEMP directory so it works on any Windows machine. + let mut diag_log: Option = None; + let diag_path = format!( + "{}\\whpx-diag.log", + std::env::var("TEMP").unwrap_or_else(|_| r"C:\Temp".to_string()) + ); + if let Ok(f) = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&diag_path) + { + diag_log = Some(f); + } + + macro_rules! diag { + ($($arg:tt)*) => { + if let Some(ref mut f) = diag_log { + let _ = writeln!(f, $($arg)*); + let _ = f.flush(); + } + }; + } + diag!("\n=== VM START ctx_id={} ===", ctx.id); + + // Validate required fields. + let kernel_path = ctx + .kernel_path + .as_ref() + .ok_or_else(|| WkrunError::Config("kernel_path is required for VM start".into()))?; + + // Read kernel image. + let kernel_image = std::fs::read(kernel_path).map_err(|e| { + WkrunError::Boot(format!( + "failed to read kernel '{}': {}", + kernel_path.display(), + e + )) + })?; + + // Read initrd if provided. + let initrd_data = match ctx.initramfs_path { + Some(ref path) => Some(std::fs::read(path).map_err(|e| { + WkrunError::Boot(format!("failed to read initrd '{}': {}", path.display(), e)) + })?), + None => None, + }; + + // Check WHPX availability. + if !WhpxPartition::is_available()? { + return Err(WkrunError::WhpxUnavailable( + "WHPX is not available on this system".into(), + )); + } + + // Create partition. + let partition = WhpxPartition::new()?; + partition.set_processor_count(ctx.num_vcpus as u32)?; + partition.set_extended_vm_exits(true, true)?; + + // NOTE: Do NOT enable APIC emulation here. On Win10 MBP 2014, + // set_local_apic_emulation(true) returns success but then the APIC + // doesn't function — no interrupts get delivered and the kernel hangs + // before producing any console output. Software PIC is required. + + partition.setup()?; + + // Allocate and map guest memory. + let guest_mem = Arc::new(GuestMemory::new(ctx.ram_mib)?); + guest_mem.map_to_partition(&partition)?; + + // Create devices from context. 
+ let ctx_id = ctx.id; + let setup = DeviceManager::from_context(&ctx)?; + devices::store_console_buffer(ctx_id, setup.console_buffer); + let devices = setup.devices; + + // NOTE: Block I/O workers are started lazily (deferred start) inside + // the vCPU loop, on the first MMIO write. Starting them here (before + // the vCPU runs) causes ~80% boot failure on WHPX — the worker + // thread creation appears to interfere with WHPX partition state + // during early boot. + + // Build kernel command line. + let cmdline = build_kernel_cmdline( + ctx.kernel_cmdline.as_deref(), + setup.has_root_disk, + &setup.mmio_slots, + ctx.root_disk_device.as_deref(), + ctx.root_disk_fstype.as_deref(), + ctx.exec_path.as_deref(), + &ctx.argv, + ctx.verbose, + ); + + // Load kernel. + let initrd_ref = initrd_data.as_deref(); + let (regs, sregs) = load_kernel_with_initrd( + &guest_mem, + &kernel_image, + &cmdline, + ctx.ram_mib, + initrd_ref, + ctx.num_vcpus, + )?; + + log::info!( + "Kernel loaded at 0x100000, RIP=0x{:X}, cmdline: {}", + regs.rip, + cmdline + ); + diag!("Kernel loaded, RIP={:#X}, ram={}MB", regs.rip, ctx.ram_mib); + + // Create all vCPUs. BSP (index 0) gets the boot registers. + // APs (index 1..N-1) are created but start in "wait for SIPI" state. + let num_vcpus = ctx.num_vcpus; + let mut vcpus = Vec::with_capacity(num_vcpus as usize); + for i in 0..num_vcpus as u32 { + let vcpu = WhpxVcpu::new(&partition, i)?; + if i == 0 { + vcpu.set_registers(®s)?; + vcpu.set_special_registers(&sregs)?; + } + vcpus.push(vcpu); + } + + // Collect cancellers for all vCPUs. The timer thread and stop() + // need to be able to wake any vCPU. + let cancellers: Vec = vcpus.iter().map(|v| v.canceller()).collect(); + + // Track which vCPUs have actually entered WHvRunVirtualProcessor. + // The timer thread only cancels running vCPUs — cancelling a VP that + // hasn't been run yet has undefined behavior on WHPX and may corrupt + // partition state (suspected cause of 4-vCPU BSP hang). 
+ let vcpu_running: Vec> = (0..num_vcpus as usize) + .map(|_| Arc::new(AtomicBool::new(false))) + .collect(); + + // Store BSP canceller so stop() can wake the VM. + *canceller_slot.lock().unwrap() = Some(cancellers[0].clone()); + + // Create per-AP startup state (one per AP, indexed by ap_id - 1). + let ap_states: Vec = + (1..num_vcpus).map(|_| ApStartupState::new()).collect(); + + // Shared VM shutdown flag — set by any vCPU to signal all others to exit. + let shutdown = Arc::new(AtomicBool::new(false)); + + // Extract per-vCPU LAPIC refs BEFORE wrapping in Arc>. + // These allow the runner fast path to bypass the DeviceManager lock + // for LAPIC MMIO reads/writes, eliminating cross-vCPU contention. + let lapic_refs: Vec>> = devices.get_lapic_refs(); + + // Extract per-vCPU shared APIC states for lock-free cross-vCPU interrupt delivery. + // Source vCPUs atomic-OR vector bits here; owning vCPU pulls into local IRR. + let shared_states: Vec> = devices.get_shared_states(); + + let devices = Arc::new(Mutex::new(devices)); + + // Move diag_log into shared state for BSP diagnostics. + let diag_log = Arc::new(Mutex::new(diag_log)); + + log::info!("Starting VM with {} vCPU(s), ctx_id={}", num_vcpus, ctx_id); + + let mut exit_code = 1i32; + + // Use thread::scope so all vCPU threads are guaranteed to terminate + // before we clean up resources. The BSP runs in the scoped block; + // APs are spawned as scoped threads. + { + let shutdown_ref = &shutdown; + let devices_ref = &devices; + let ap_states_ref = &ap_states; + let cancellers_ref = &cancellers; + let run_config_ref = &run_config; + let guest_mem_ref: &GuestMemory = &guest_mem; + let diag_ref = &diag_log; + let lapic_refs_ref = &lapic_refs; + let shared_states_ref = &shared_states; + let vcpu_running_ref = &vcpu_running; + + // Spawn timer thread — cancels only RUNNING vCPUs every 1ms. + // Previously this cancelled ALL vCPUs including APs that hadn't + // entered WHvRunVirtualProcessor yet. 
At 4 vCPUs, calling + // WHvCancelRunVirtualProcessor on 3 non-running APs every 1ms + // may corrupt WHPX partition state, causing BSP hang. + let timer_flag = run_config.running.clone(); + let timer_cancellers: Vec = cancellers.clone(); + let timer_shutdown = shutdown.clone(); + let timer_vcpu_running: Vec> = vcpu_running.clone(); + let timer_thread = std::thread::spawn(move || { + while timer_flag.load(Ordering::Relaxed) && !timer_shutdown.load(Ordering::Relaxed) + { + std::thread::sleep(Duration::from_millis(1)); + for (i, c) in timer_cancellers.iter().enumerate() { + if timer_vcpu_running[i].load(Ordering::Relaxed) { + let _ = c.cancel(); + } + } + } + }); + + std::thread::scope(|s| { + // Spawn AP threads (vCPU 1..N-1). + for ap_idx in 1..num_vcpus as usize { + let vcpu = &vcpus[ap_idx]; + let my_lapic = &lapic_refs_ref[ap_idx]; + let my_shared = &shared_states_ref[ap_idx]; + let my_running = &vcpu_running_ref[ap_idx]; + s.spawn(move || { + run_ap_loop( + ap_idx as u8, + num_vcpus, + vcpu, + devices_ref, + guest_mem_ref, + shutdown_ref, + run_config_ref, + cancellers_ref, + &ap_states_ref[ap_idx - 1], + ctx_id, + diag_ref, + my_lapic, + my_shared, + shared_states_ref, + my_running, + ); + }); + } + + // BSP runs on the current thread. + // Mark BSP as running before entering the loop so timer can cancel it. + vcpu_running_ref[0].store(true, Ordering::Release); + let bsp_vcpu = &vcpus[0]; + let bsp_code = run_bsp_loop( + bsp_vcpu, + devices_ref, + guest_mem_ref, + shutdown_ref, + run_config_ref, + cancellers_ref, + ap_states_ref, + ctx_id, + diag_ref, + num_vcpus, + &lapic_refs_ref[0], + &shared_states_ref[0], + shared_states_ref, + ); + // BSP exited — signal all APs to exit. + shutdown_ref.store(true, Ordering::Release); + for c in cancellers_ref { + let _ = c.cancel(); + } + // Wake any APs still waiting for SIPI. 
+ for ap in ap_states_ref { + *ap.started.lock().unwrap() = true; + ap.condvar.notify_one(); + } + exit_code = bsp_code; + }); + + // Stop the timer thread and block I/O workers. + run_config.request_stop(); + shutdown.store(true, Ordering::Release); + devices.lock().unwrap().stop_blk_workers(); + let _ = timer_thread.join(); + } + + log::info!("VM exited with code {}", exit_code); + + // Clean up diagnostic file on normal exit. + // Drop the file handle first, then remove the temp file. + drop(diag_log); + let _ = std::fs::remove_file(&diag_path); + + Ok(exit_code) + } + + /// Per-vCPU statistics counters. + struct VcpuStats { + exit_count: u64, + halt_count: u64, + total_halt_exits: u64, + halt_with_irq: u64, + mmio_count: u64, + serial_out_count: u64, + io_out_count: u64, + io_in_count: u64, + inject_count: u64, + cancelled_count: u64, + cpuid_count: u64, + last_progress: Instant, + start_time: Instant, + window_requested: bool, + /// Last MMIO read address (for tight loop detection). + last_mmio_read_addr: u64, + /// Consecutive MMIO reads to the same address. + consecutive_mmio_reads: u64, + } + + impl VcpuStats { + fn new() -> Self { + let now = Instant::now(); + Self { + exit_count: 0, + halt_count: 0, + total_halt_exits: 0, + halt_with_irq: 0, + mmio_count: 0, + serial_out_count: 0, + io_out_count: 0, + io_in_count: 0, + inject_count: 0, + cancelled_count: 0, + cpuid_count: 0, + last_progress: now, + start_time: now, + window_requested: false, + last_mmio_read_addr: u64::MAX, + consecutive_mmio_reads: 0, + } + } + } + + /// Try to inject a pending interrupt into a vCPU. + /// + /// Returns the number of interrupts injected (0 or 1). 
+ fn try_inject_interrupt( + vcpu: &WhpxVcpu, + vcpu_id: u8, + devices: &mut DeviceManager, + stats: &mut VcpuStats, + ) -> Result<()> { + if !devices.irq_chip.has_pending(vcpu_id) { + return Ok(()); + } + + let already_pending = vcpu.has_pending_interruption().unwrap_or(false); + if already_pending { + return Ok(()); + } + + match vcpu.interrupts_enabled() { + Ok(true) => { + if let Some(vector) = devices.irq_chip.acknowledge(vcpu_id) { + log::debug!("vCPU{}: injecting interrupt vector {:#X}", vcpu_id, vector); + vcpu.inject_interrupt(vector)?; + devices.irq_chip.notify_injected(vcpu_id, vector); + stats.window_requested = false; + stats.inject_count += 1; + } + } + Ok(false) => { + if !stats.window_requested { + vcpu.request_interrupt_window()?; + stats.window_requested = true; + } + } + Err(ref e) => { + log::warn!("vCPU{}: interrupts_enabled() error: {:?}", vcpu_id, e); + } + } + Ok(()) + } + + /// Lock-free interrupt injection for APIC mode — no DeviceManager lock. + /// + /// Checks the per-vCPU LAPIC for injectable vectors and injects directly + /// via the WHPX vCPU API. Only the owning vCPU's LAPIC mutex is acquired. 
+ fn try_inject_interrupt_fast( + vcpu: &WhpxVcpu, + lapic: &Mutex, + stats: &mut VcpuStats, + ) -> Result<()> { + let has_injectable = lapic.lock().unwrap().get_highest_injectable().is_some(); + if !has_injectable { + return Ok(()); + } + + let already_pending = vcpu.has_pending_interruption().unwrap_or(false); + if already_pending { + return Ok(()); + } + + match vcpu.interrupts_enabled() { + Ok(true) => { + let mut guard = lapic.lock().unwrap(); + if let Some(vector) = guard.get_highest_injectable() { + guard.start_of_interrupt(vector); + drop(guard); + vcpu.inject_interrupt(vector)?; + stats.inject_count += 1; + stats.window_requested = false; + } + } + Ok(false) => { + if !stats.window_requested { + vcpu.request_interrupt_window()?; + stats.window_requested = true; + } + } + Err(ref e) => { + log::warn!( + "try_inject_interrupt_fast: interrupts_enabled() error: {:?}", + e + ); + } + } + Ok(()) + } + + /// Dispatch an IPI action from a LAPIC ICR write. + fn dispatch_ipi( + action: IpiAction, + devices: &mut DeviceManager, + ap_states: &[ApStartupState], + cancellers: &[VcpuCanceller], + diag_log: &Arc>>, + start_time: Instant, + ) { + macro_rules! 
ipi_diag { + ($($arg:tt)*) => { + if let Ok(mut guard) = diag_log.lock() { + if let Some(ref mut f) = *guard { + let _ = write!(f, "[{:.3}s] ", start_time.elapsed().as_secs_f64()); + let _ = writeln!(f, $($arg)*); + let _ = f.flush(); + } + } + }; + } + match action { + IpiAction::None => {} + IpiAction::SendInit { target_apic_id } => { + let ap_idx = target_apic_id as usize; + if ap_idx > 0 && ap_idx - 1 < ap_states.len() { + ap_states[ap_idx - 1] + .init_received + .store(true, Ordering::Release); + ipi_diag!("IPI: INIT delivered to AP{}", target_apic_id); + } else { + ipi_diag!( + "IPI: INIT target AP{} out of range (max={})", + target_apic_id, + ap_states.len() + ); + } + } + IpiAction::SendSipi { + target_apic_id, + vector, + } => { + let ap_idx = target_apic_id as usize; + if ap_idx > 0 && ap_idx - 1 < ap_states.len() { + let state = &ap_states[ap_idx - 1]; + if state.init_received.load(Ordering::Acquire) { + *state.sipi_vector.lock().unwrap() = Some(vector); + *state.started.lock().unwrap() = true; + state.condvar.notify_one(); + ipi_diag!( + "IPI: SIPI delivered to AP{}, vector={:#X}, start_addr={:#X}", + target_apic_id, + vector, + (vector as u64) * 0x1000 + ); + } else { + ipi_diag!( + "IPI: SIPI to AP{} IGNORED (no INIT received)", + target_apic_id, + ); + } + } + } + IpiAction::SendInterrupt { + target_apic_id, + vector, + } => { + devices + .irq_chip + .deliver_ipi_interrupt(target_apic_id, vector); + let idx = target_apic_id as usize; + if idx < cancellers.len() { + let _ = cancellers[idx].cancel(); + } + ipi_diag!( + "IPI: interrupt vector={:#X} → vCPU{}", + vector, + target_apic_id, + ); + } + IpiAction::BroadcastInterrupt { + source_apic_id, + vector, + } => { + // Send to all vCPUs except the source. 
+ let num = cancellers.len(); + for idx in 0..num { + if idx as u8 != source_apic_id { + devices.irq_chip.deliver_ipi_interrupt(idx as u8, vector); + let _ = cancellers[idx].cancel(); + } + } + ipi_diag!( + "IPI: broadcast vector={:#X} from vCPU{} → all-excl-self ({} targets)", + vector, + source_apic_id, + num - 1, + ); + } + } + } + + /// Fast path: read from LAPIC MMIO without acquiring DeviceManager lock. + /// + /// All LAPIC register reads are safe to handle via the per-vCPU LAPIC lock + /// since they only access the vCPU's own LAPIC state (IRR, ISR, TPR, CCR, etc). + /// Returns Some(value) if the address is in the LAPIC range, None otherwise. + fn handle_lapic_mmio_read_fast(lapic: &Mutex, address: u64) -> Option { + if address >= LAPIC_MMIO_BASE && address < LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE { + let offset = address - LAPIC_MMIO_BASE; + Some(lapic.lock().unwrap().read_mmio(offset) as u64) + } else { + None + } + } + + /// Result of a fast-path LAPIC MMIO write. + enum LapicWriteFastResult { + /// Write handled completely (no further action needed). + Handled, + /// ICR Low write: IPI action needs dispatching inline (lock-free). + IpiAction(IpiAction), + /// Not handled: needs DeviceManager slow path (EOI, SVR, non-LAPIC). + NotHandled, + } + + /// Fast path: write to LAPIC MMIO without acquiring DeviceManager lock. + /// + /// Handles most LAPIC registers directly via per-vCPU lock. ICR Low writes + /// are parsed and returned as IpiAction for inline dispatch (lock-free IPI). + /// Only EOI (→ IOAPIC propagation) and SVR (→ APIC transition check) + /// require the DeviceManager slow path. + fn handle_lapic_mmio_write_fast( + lapic: &Mutex, + address: u64, + data: u64, + ) -> LapicWriteFastResult { + if address >= LAPIC_MMIO_BASE && address < LAPIC_MMIO_BASE + LAPIC_MMIO_SIZE { + let offset = address - LAPIC_MMIO_BASE; + match offset { + // TPR: no cross-device side effects. + 0x080 | + // ICR High: sets destination, no IPI dispatch until ICR Low write. 
+ 0x310 | + // LVT Timer: configures timer mode/vector, no immediate side effects. + 0x320 | + // Timer Initial Count: starts/resets timer countdown. + 0x380 | + // Timer Divide Config: sets timer divider. + 0x3E0 => { + lapic.lock().unwrap().write_mmio(offset, data as u32); + LapicWriteFastResult::Handled + } + // ICR Low (0x300): parse ICR and return IPI action for inline dispatch. + // This eliminates the DeviceManager lock for ALL IPI dispatch. + 0x300 => { + let result = lapic.lock().unwrap().write_mmio(offset, data as u32); + LapicWriteFastResult::IpiAction(result.ipi_action) + } + // EOI (0x0B0): needs IOAPIC propagation → DeviceManager path. + // SVR (0x0F0): needs APIC transition check → DeviceManager path. + _ => LapicWriteFastResult::NotHandled, + } + } else { + LapicWriteFastResult::NotHandled + } + } + + /// BSP (Bootstrap Processor, vCPU 0) main loop. + /// + /// Handles timer ticking, device polling, interrupt injection, block worker + /// start, IPI dispatch, and progress diagnostics. + #[allow(clippy::too_many_arguments)] + fn run_bsp_loop( + vcpu: &WhpxVcpu, + devices: &Arc>, + guest_mem: &GuestMemory, + shutdown: &AtomicBool, + run_config: &VcpuRunConfig, + cancellers: &[VcpuCanceller], + ap_states: &[ApStartupState], + ctx_id: u32, + diag_log: &Arc>>, + num_vcpus: u8, + my_lapic: &Arc>, + my_shared: &Arc, + all_shared: &[Arc], + ) -> i32 { + macro_rules! diag { + ($($arg:tt)*) => { + if let Ok(mut guard) = diag_log.lock() { + if let Some(ref mut f) = *guard { + let _ = writeln!(f, $($arg)*); + let _ = f.flush(); + } + } + }; + } + + let mut stats = VcpuStats::new(); + let mut blk_workers_started = false; + let sync_block = std::env::var("BOXLITE_SYNC_BLOCK").is_ok(); + let mut _last_exit_reason = "none"; + + loop { + if shutdown.load(Ordering::Relaxed) || !run_config.should_run() { + _last_exit_reason = "SHUTDOWN_SIGNAL"; + return 0; + } + + // 1. Pull remote interrupts (lock-free) + tick own LAPIC timer. 
+ { + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + if let Some(vector) = lapic.tick_timer(Instant::now()) { + lapic.accept_interrupt(vector); + } + } + // 2. Lock-free interrupt injection (APIC mode, per-vCPU only). + if let Err(e) = try_inject_interrupt_fast(vcpu, my_lapic, &mut stats) { + log::error!("BSP lock-free inject error: {:?}", e); + } + // 3. Tick PIT + poll devices (reduced lock time — no LAPIC timer loop). + { + let mut dm = devices.lock().unwrap(); + dm.tick_and_poll(0, guest_mem); + // PIC mode fallback: inject via DeviceManager path. + if !dm.irq_chip.apic_mode() { + if let Err(e) = try_inject_interrupt(vcpu, 0, &mut dm, &mut stats) { + log::error!("BSP PIC inject error: {:?}", e); + } + } + } + // 4. Pull any interrupts raised by tick_and_poll (device completions, + // PIT timer) and inject before entering the guest. + { + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + } + if let Err(e) = try_inject_interrupt_fast(vcpu, my_lapic, &mut stats) { + log::error!("BSP post-poll inject error: {:?}", e); + } + + let exit = match vcpu.run() { + Ok(exit) => exit, + Err(e) => { + log::error!( + "BSP vcpu.run() FAILED after {} exits: {:?}", + stats.exit_count, + e + ); + return 1; + } + }; + stats.exit_count += 1; + + // Periodic progress logging for early-boot diagnostics. + // Written to diag file (not log::info!) because shim's tracing + // subscriber doesn't capture log crate output from VMM. 
+ if stats.exit_count % 50_000 == 0 { + let console_len = devices::get_console_output(ctx_id) + .map(|b| b.len()) + .unwrap_or(0); + diag!( + "BSP progress: exit={} serial={} mmio={} io_out={} io_in={} inj={} hlt={} console={}B elapsed={:.1}s", + stats.exit_count, stats.serial_out_count, stats.mmio_count, + stats.io_out_count, stats.io_in_count, stats.inject_count, + stats.total_halt_exits, console_len, + stats.start_time.elapsed().as_secs_f64(), + ); + } + + match exit { + VcpuExit::IoOut { port, size, data } => { + stats.halt_count = 0; + stats.io_out_count += 1; + if port == 0x3F8 { + stats.serial_out_count += 1; + } + let mut dm = devices.lock().unwrap(); + dm.handle_io_out(port, size, data); + if dm.shutdown_requested() { + log::info!("ACPI shutdown detected after {} exits", stats.exit_count); + _last_exit_reason = "ACPI_SHUTDOWN"; + return 0; + } + drop(dm); + if let Err(e) = vcpu.skip_instruction() { + log::error!("BSP skip_instruction error: {:?}", e); + return 1; + } + } + VcpuExit::IoIn { port, size } => { + stats.halt_count = 0; + stats.io_in_count += 1; + let data = devices.lock().unwrap().handle_io_in(port, size); + if let Err(e) = vcpu.complete_io_in(data, size) { + log::error!("BSP complete_io_in error: {:?}", e); + return 1; + } + } + VcpuExit::MmioRead { address, size } => { + stats.halt_count = 0; + stats.mmio_count += 1; + // Detect tight MMIO read loops (same address read 10K+ times). + // This catches BSP hang during LAPIC timer calibration. 
+ if address == stats.last_mmio_read_addr { + stats.consecutive_mmio_reads += 1; + if stats.consecutive_mmio_reads == 10_000 { + diag!( + "BSP: tight MMIO read loop: addr={:#X} count={} exit={}", + address, + stats.consecutive_mmio_reads, + stats.exit_count + ); + if let Ok(regs) = vcpu.get_registers() { + diag!("BSP: RIP={:#X} at tight MMIO loop", regs.rip); + } + } + } else { + stats.last_mmio_read_addr = address; + stats.consecutive_mmio_reads = 1; + } + // Fast path: LAPIC reads bypass DeviceManager lock. + let data = if let Some(val) = handle_lapic_mmio_read_fast(my_lapic, address) { + val + } else { + devices.lock().unwrap().handle_mmio_read(0, address, size) + }; + if let Err(e) = vcpu.complete_mmio_read(data) { + log::error!("BSP complete_mmio_read error: {:?}", e); + return 1; + } + } + VcpuExit::MmioWrite { + address, + size, + data, + } => { + stats.halt_count = 0; + stats.mmio_count += 1; + match handle_lapic_mmio_write_fast(my_lapic, address, data) { + LapicWriteFastResult::Handled => { + // Fast path: handled via per-vCPU lock only. + } + LapicWriteFastResult::IpiAction(action) => { + // ICR fast path: dispatch IPI inline (lock-free). + match action { + IpiAction::SendInterrupt { + target_apic_id, + vector, + } => { + let idx = target_apic_id as usize; + if idx < all_shared.len() { + all_shared[idx].request_interrupt(vector); + if idx < cancellers.len() { + let _ = cancellers[idx].cancel(); + } + } + } + IpiAction::BroadcastInterrupt { + source_apic_id, + vector, + } => { + // Broadcast to all vCPUs except source (lock-free). + for idx in 0..all_shared.len() { + if idx as u8 != source_apic_id { + all_shared[idx].request_interrupt(vector); + if idx < cancellers.len() { + let _ = cancellers[idx].cancel(); + } + } + } + } + IpiAction::SendInit { .. } | IpiAction::SendSipi { .. } => { + // INIT/SIPI use existing condvar mechanism. 
+ dispatch_ipi( + action, + &mut devices.lock().unwrap(), + ap_states, + cancellers, + diag_log, + stats.start_time, + ); + } + IpiAction::None => {} + } + } + LapicWriteFastResult::NotHandled => { + // Slow path: needs DeviceManager for EOI/SVR or non-LAPIC devices. + let mut dm = devices.lock().unwrap(); + if !blk_workers_started && !sync_block { + dm.start_blk_workers(); + blk_workers_started = true; + log::info!( + target: "whpx::diag", + "Block workers started at exit={} mmio={} elapsed={:.1}ms", + stats.exit_count, + stats.mmio_count, + stats.start_time.elapsed().as_secs_f64() * 1000.0 + ); + } + let ipi_action = + dm.handle_mmio_write(0, address, size, data, guest_mem); + if !matches!(ipi_action, IpiAction::None) { + dispatch_ipi( + ipi_action, + &mut dm, + ap_states, + cancellers, + diag_log, + stats.start_time, + ); + } + drop(dm); + } + } + if let Err(e) = vcpu.skip_instruction() { + log::error!("BSP skip_instruction error: {:?}", e); + return 1; + } + } + VcpuExit::InterruptWindow => { + stats.halt_count = 0; + stats.window_requested = false; + } + VcpuExit::Halt => { + stats.total_halt_exits += 1; + if !run_config.should_run() || shutdown.load(Ordering::Relaxed) { + log::info!("BSP: stop requested, exiting on Halt"); + return 0; + } + + // Pull remote interrupts + tick timer before sleeping. + { + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + if let Some(vector) = lapic.tick_timer(Instant::now()) { + lapic.accept_interrupt(vector); + } + if lapic.get_highest_injectable().is_some() { + drop(lapic); + if let Err(e) = try_inject_interrupt_fast(vcpu, my_lapic, &mut stats) { + log::error!("BSP HLT inject error: {:?}", e); + } + stats.halt_with_irq += 1; + stats.halt_count = 0; + continue; + } + } + // Also poll PIT + devices (PIC mode fallback + block I/O). 
+ { + let mut dm = devices.lock().unwrap(); + dm.tick_and_poll(0, guest_mem); + if !dm.irq_chip.apic_mode() && dm.irq_chip.has_pending(0) { + let already_pending = vcpu.has_pending_interruption().unwrap_or(false); + if !already_pending { + if let Some(vector) = dm.irq_chip.acknowledge(0) { + let _ = vcpu.inject_interrupt(vector); + dm.irq_chip.notify_injected(0, vector); + stats.window_requested = false; + stats.inject_count += 1; + } + } + stats.halt_with_irq += 1; + stats.halt_count = 0; + continue; + } + } + + stats.halt_count += 1; + + if stats.halt_count % 1000 == 0 { + if let Ok(regs) = vcpu.get_registers() { + let console_len = devices::get_console_output(ctx_id) + .map(|b| b.len()) + .unwrap_or(0); + let if_flag = vcpu.interrupts_enabled().unwrap_or(false); + diag!( + "BSP HLT stuck: consecutive={} total_halt={} halt_w_irq={} \ + exits={} RIP={:#X} IF={} console={}B mmio={} vcpus={}", + stats.halt_count, + stats.total_halt_exits, + stats.halt_with_irq, + stats.exit_count, + regs.rip, + if_flag, + console_len, + stats.mmio_count, + num_vcpus + ); + } + } + + if stats.halt_count > MAX_HALTS { + log::warn!( + "BSP halted {} times consecutively after {} exits", + stats.halt_count, + stats.exit_count, + ); + _last_exit_reason = "HALT_MAX_REACHED"; + return 0; + } + + // Tiered sleep: spin-yield phase to catch imminent interrupts, + // then short sleep if no interrupt arrived. + let mut woke_by_irq = false; + for i in 0..HLT_SPIN_ITERS { + std::thread::yield_now(); + if i % 10 == 9 { + // Fast check: pull_irr + per-LAPIC (no DeviceManager lock). + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + if lapic.get_highest_injectable().is_some() { + woke_by_irq = true; + break; + } + drop(lapic); + // Slow check: tick PIT + poll devices. 
+ let mut dm = devices.lock().unwrap(); + dm.tick_and_poll(0, guest_mem); + } + } + if !woke_by_irq { + std::thread::sleep(Duration::from_micros(HLT_SLEEP_US)); + } + } + VcpuExit::Shutdown => { + log::info!("BSP: VM shutdown after {} exits", stats.exit_count); + return 0; + } + VcpuExit::Cancelled => { + if !run_config.should_run() || shutdown.load(Ordering::Relaxed) { + log::info!("BSP: stop requested on Cancelled"); + return 0; + } + if stats.last_progress.elapsed() >= Duration::from_secs(2) { + stats.last_progress = Instant::now(); + if let Ok(regs) = vcpu.get_registers() { + let dm = devices.lock().unwrap(); + let console_len = devices::get_console_output(ctx_id) + .map(|b| b.len()) + .unwrap_or(0); + let (qn, bc) = dm.blk_stats(); + let apic_mode = dm.irq_chip.apic_mode(); + let blk_mode = if sync_block { + "sync" + } else if blk_workers_started { + "async" + } else { + "pending" + }; + drop(dm); + diag!( + "vCPU0 @ {:.1}s: exits={} RIP={:#X} console={}B mmio={} halt={}/{} inj={} blk_comp={} mode={}/{} io_out={} serial={} blk_qn={} vcpus={}", + stats.start_time.elapsed().as_secs_f64(), + stats.exit_count, regs.rip, console_len, + stats.mmio_count, stats.halt_count, stats.total_halt_exits, + stats.inject_count, bc, + if apic_mode { "apic" } else { "pic" }, blk_mode, + stats.io_out_count, stats.serial_out_count, qn, num_vcpus, + ); + } + } + } + VcpuExit::MsrAccess { + msr_number, + is_write, + rax, + rdx, + } => { + stats.halt_count = 0; + if is_write { + log::trace!( + "BSP: MSR write 0x{:08X} <- 0x{:016X}", + msr_number, + (rdx << 32) | (rax & 0xFFFF_FFFF) + ); + if let Err(e) = vcpu.skip_instruction() { + log::error!("BSP skip_instruction error: {:?}", e); + return 1; + } + } else { + let value = super::handle_msr_read(0, msr_number); + log::trace!("BSP: MSR read 0x{:08X} -> 0x{:X}", msr_number, value); + if let Err(e) = vcpu.complete_msr_read(value) { + log::error!("BSP complete_msr_read error: {:?}", e); + return 1; + } + } + } + VcpuExit::CpuidAccess { + 
rax, + rcx, + default_rax, + default_rbx, + default_rcx, + default_rdx, + } => { + stats.halt_count = 0; + stats.cpuid_count += 1; + let (out_rax, out_rbx, out_rcx, out_rdx) = super::handle_cpuid( + 0, + num_vcpus, + rax as u32, + rcx, + default_rax, + default_rbx, + default_rcx, + default_rdx, + ); + log::trace!( + "BSP CPUID leaf=0x{:X} sub=0x{:X} -> rax=0x{:X}", + rax, + rcx, + out_rax + ); + if let Err(e) = vcpu.complete_cpuid(out_rax, out_rbx, out_rcx, out_rdx) { + log::error!("BSP complete_cpuid error: {:?}", e); + return 1; + } + } + VcpuExit::UnrecoverableException => { + let regs = vcpu.get_registers().ok(); + let sregs = vcpu.get_special_registers().ok(); + log::error!( + "BSP: Unrecoverable exception after {} exits. \ + RIP={:#X}, CR0={:#X}, CR3={:#X}, CR4={:#X}, EFER={:#X}", + stats.exit_count, + regs.as_ref().map_or(0, |r| r.rip), + sregs.as_ref().map_or(0, |s| s.cr0), + sregs.as_ref().map_or(0, |s| s.cr3), + sregs.as_ref().map_or(0, |s| s.cr4), + sregs.as_ref().map_or(0, |s| s.efer), + ); + return -1; + } + VcpuExit::Unknown(reason) => { + log::error!( + "BSP: Unknown exit reason {} after {} exits", + reason, + stats.exit_count + ); + return -1; + } + } + + if stats.exit_count >= MAX_EXITS { + log::warn!("BSP reached {} exit limit", MAX_EXITS); + return -1; + } + } + } + + /// AP (Application Processor, vCPU 1..N-1) loop. + /// + /// Waits for SIPI, configures initial registers, then runs a vCPU loop + /// similar to BSP but without timer ticking or block worker management. + #[allow(clippy::too_many_arguments)] + fn run_ap_loop( + ap_id: u8, + num_vcpus: u8, + vcpu: &WhpxVcpu, + devices: &Arc>, + guest_mem: &GuestMemory, + shutdown: &AtomicBool, + run_config: &VcpuRunConfig, + cancellers: &[VcpuCanceller], + startup: &ApStartupState, + _ctx_id: u32, + diag_log: &Arc>>, + my_lapic: &Arc>, + my_shared: &Arc, + all_shared: &[Arc], + vcpu_running_flag: &AtomicBool, + ) { + macro_rules! 
diag { + ($($arg:tt)*) => { + if let Ok(mut guard) = diag_log.lock() { + if let Some(ref mut f) = *guard { + let _ = writeln!(f, $($arg)*); + let _ = f.flush(); + } + } + }; + } + + diag!("AP{}: thread started, waiting for SIPI", ap_id); + + // Wait for SIPI from BSP. + { + let mut started = startup.started.lock().unwrap(); + while !*started { + started = startup.condvar.wait(started).unwrap(); + } + } + + // Check if we were woken for shutdown rather than SIPI. + if shutdown.load(Ordering::Relaxed) { + diag!("AP{}: woken for shutdown, not SIPI", ap_id); + return; + } + + // Configure AP initial register state from SIPI vector. + let sipi_vector = startup.sipi_vector.lock().unwrap().unwrap_or(0); + diag!( + "AP{}: SIPI received, vector={:#X}, CS:IP={:#X}:0000", + ap_id, + sipi_vector, + (sipi_vector as u64) * 0x1000 + ); + + // Dump WHPX default registers before modification for diagnostics. + if let Ok(sregs) = vcpu.get_special_registers() { + diag!( + "AP{}: WHPX defaults TR=sel:{:#X}/base:{:#X}/lim:{:#X}/ar:{:#X} \ + LDT=sel:{:#X}/base:{:#X}/lim:{:#X}/ar:{:#X} \ + GDT=base:{:#X}/lim:{:#X} IDT=base:{:#X}/lim:{:#X} \ + CR0={:#X} CR4={:#X} EFER={:#X}", + ap_id, + sregs.tr.selector, + sregs.tr.base, + sregs.tr.limit, + sregs.tr.access_rights, + sregs.ldt.selector, + sregs.ldt.base, + sregs.ldt.limit, + sregs.ldt.access_rights, + sregs.gdt.base, + sregs.gdt.limit, + sregs.idt.base, + sregs.idt.limit, + sregs.cr0, + sregs.cr4, + sregs.efer + ); + } + + // AP starts in real mode: CS:IP = (sipi_vector * 0x100):0x0000 + // The Linux kernel SMP trampoline is placed at sipi_vector * 0x1000. + if let Err(e) = vcpu.set_ap_initial_regs(sipi_vector, ap_id) { + diag!("AP{}: FAILED to set initial registers: {:?}", ap_id, e); + return; + } + diag!("AP{}: initial regs set, entering run loop", ap_id); + + // Mark AP as running so the timer thread can cancel it. + // This MUST happen after SIPI wake + register setup, just before first vcpu.run(). 
+ vcpu_running_flag.store(true, Ordering::Release); + + let mut stats = VcpuStats::new(); + + loop { + if shutdown.load(Ordering::Relaxed) || !run_config.should_run() { + diag!( + "AP{}: EXIT (shutdown) exits={} cancelled={} halt={} cpuid={} mmio={}", + ap_id, + stats.exit_count, + stats.cancelled_count, + stats.total_halt_exits, + stats.cpuid_count, + stats.mmio_count, + ); + return; + } + + // Pull remote interrupts (lock-free) + tick own LAPIC timer. + { + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + if let Some(vector) = lapic.tick_timer(Instant::now()) { + lapic.accept_interrupt(vector); + } + } + // Lock-free interrupt injection (APIC mode, per-vCPU only). + if let Err(e) = try_inject_interrupt_fast(vcpu, my_lapic, &mut stats) { + log::error!("AP{}: lock-free inject error: {:?}", ap_id, e); + } + + let exit = match vcpu.run() { + Ok(exit) => exit, + Err(e) => { + diag!( + "AP{}: vcpu.run() FAILED after {} exits: {:?}", + ap_id, + stats.exit_count, + e + ); + return; + } + }; + stats.exit_count += 1; + + // Log first few AP exits for diagnostics. + if stats.exit_count <= 10 { + let desc = match &exit { + VcpuExit::IoOut { port, .. } => format!("IoOut(port={:#X})", port), + VcpuExit::IoIn { port, .. } => format!("IoIn(port={:#X})", port), + VcpuExit::MmioRead { address, .. } => format!("MmioRead({:#X})", address), + VcpuExit::MmioWrite { address, .. } => format!("MmioWrite({:#X})", address), + VcpuExit::Halt => "Halt".into(), + VcpuExit::Cancelled => "Cancelled".into(), + VcpuExit::InterruptWindow => "InterruptWindow".into(), + VcpuExit::Shutdown => "Shutdown".into(), + VcpuExit::UnrecoverableException => "UnrecoverableException".into(), + VcpuExit::MsrAccess { + msr_number, + is_write, + .. + } => format!("MSR({:#X}, write={})", msr_number, is_write), + VcpuExit::CpuidAccess { rax, .. 
} => format!("CPUID({:#X})", rax), + VcpuExit::Unknown(r) => format!("Unknown({})", r), + }; + diag!("AP{}: exit #{} = {}", ap_id, stats.exit_count, desc); + } + + match exit { + VcpuExit::IoOut { port, size, data } => { + stats.halt_count = 0; + stats.io_out_count += 1; + let mut dm = devices.lock().unwrap(); + dm.handle_io_out(port, size, data); + if dm.shutdown_requested() { + log::info!("AP{}: ACPI shutdown detected", ap_id); + shutdown.store(true, Ordering::Release); + for c in cancellers { + let _ = c.cancel(); + } + return; + } + drop(dm); + let _ = vcpu.skip_instruction(); + } + VcpuExit::IoIn { port, size } => { + stats.halt_count = 0; + stats.io_in_count += 1; + let data = devices.lock().unwrap().handle_io_in(port, size); + let _ = vcpu.complete_io_in(data, size); + } + VcpuExit::MmioRead { address, size } => { + stats.halt_count = 0; + stats.mmio_count += 1; + // Detect tight MMIO read loops (same address read 10K+ times). + if address == stats.last_mmio_read_addr { + stats.consecutive_mmio_reads += 1; + if stats.consecutive_mmio_reads == 10_000 { + log::warn!( + "AP{}: tight MMIO read loop: addr={:#X} count={} exit={}", + ap_id, + address, + stats.consecutive_mmio_reads, + stats.exit_count + ); + if let Ok(regs) = vcpu.get_registers() { + log::warn!("AP{}: RIP={:#X} at tight MMIO loop", ap_id, regs.rip); + } + } + } else { + stats.last_mmio_read_addr = address; + stats.consecutive_mmio_reads = 1; + } + // Fast path: LAPIC reads bypass DeviceManager lock. + let data = if let Some(val) = handle_lapic_mmio_read_fast(my_lapic, address) { + val + } else { + devices + .lock() + .unwrap() + .handle_mmio_read(ap_id, address, size) + }; + let _ = vcpu.complete_mmio_read(data); + } + VcpuExit::MmioWrite { + address, + size, + data, + } => { + stats.halt_count = 0; + stats.mmio_count += 1; + match handle_lapic_mmio_write_fast(my_lapic, address, data) { + LapicWriteFastResult::Handled => { + // Fast path: handled via per-vCPU lock only. 
+ } + LapicWriteFastResult::IpiAction(action) => { + // ICR fast path: dispatch IPI inline (lock-free). + match action { + IpiAction::SendInterrupt { + target_apic_id, + vector, + } => { + let idx = target_apic_id as usize; + if idx < all_shared.len() { + all_shared[idx].request_interrupt(vector); + if idx < cancellers.len() { + let _ = cancellers[idx].cancel(); + } + } + } + IpiAction::BroadcastInterrupt { + source_apic_id, + vector, + } => { + // Broadcast to all vCPUs except source (lock-free). + for idx in 0..all_shared.len() { + if idx as u8 != source_apic_id { + all_shared[idx].request_interrupt(vector); + if idx < cancellers.len() { + let _ = cancellers[idx].cancel(); + } + } + } + } + IpiAction::SendInit { .. } | IpiAction::SendSipi { .. } => { + dispatch_ipi( + action, + &mut devices.lock().unwrap(), + &[], + cancellers, + diag_log, + stats.start_time, + ); + } + IpiAction::None => {} + } + } + LapicWriteFastResult::NotHandled => { + // Slow path: needs DeviceManager for EOI/SVR or non-LAPIC devices. + let mut dm = devices.lock().unwrap(); + let ipi_action = + dm.handle_mmio_write(ap_id, address, size, data, guest_mem); + if !matches!(ipi_action, IpiAction::None) { + dispatch_ipi( + ipi_action, + &mut dm, + &[], + cancellers, + diag_log, + stats.start_time, + ); + } + drop(dm); + } + } + let _ = vcpu.skip_instruction(); + } + VcpuExit::InterruptWindow => { + stats.halt_count = 0; + stats.window_requested = false; + } + VcpuExit::Halt => { + stats.total_halt_exits += 1; + if shutdown.load(Ordering::Relaxed) || !run_config.should_run() { + log::info!("AP{}: stop requested on Halt", ap_id); + return; + } + + // Pull remote interrupts + tick timer before checking. + { + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + if let Some(vector) = lapic.tick_timer(Instant::now()) { + lapic.accept_interrupt(vector); + } + } + // Lock-free check: inject from per-vCPU LAPIC only. 
+ if my_lapic.lock().unwrap().get_highest_injectable().is_some() { + if let Err(e) = try_inject_interrupt_fast(vcpu, my_lapic, &mut stats) { + log::error!("AP{}: HLT inject error: {:?}", ap_id, e); + } + stats.halt_with_irq += 1; + stats.halt_count = 0; + continue; + } + + stats.halt_count += 1; + + if stats.halt_count > MAX_HALTS { + log::info!( + "AP{}: halted {} times, idling (not treated as fatal)", + ap_id, + stats.halt_count, + ); + // APs can idle indefinitely — the kernel may park them. + // Don't exit, just keep waiting for interrupts. + stats.halt_count = 0; + } + + // Tiered sleep: spin-yield phase to catch imminent interrupts, + // then short sleep if no interrupt arrived. + // Fast path: pull_irr + per-LAPIC only (no DeviceManager lock). + let mut woke_by_irq = false; + for i in 0..HLT_SPIN_ITERS { + std::thread::yield_now(); + if i % 10 == 9 { + let mut lapic = my_lapic.lock().unwrap(); + lapic.pull_irr(my_shared); + if lapic.get_highest_injectable().is_some() { + woke_by_irq = true; + break; + } + } + } + if !woke_by_irq { + std::thread::sleep(Duration::from_micros(HLT_SLEEP_US)); + } + } + VcpuExit::Shutdown => { + log::info!("AP{}: shutdown after {} exits", ap_id, stats.exit_count); + return; + } + VcpuExit::Cancelled => { + if shutdown.load(Ordering::Relaxed) || !run_config.should_run() { + log::info!("AP{}: stop requested on Cancelled", ap_id); + return; + } + stats.cancelled_count += 1; + // Periodic AP progress logging (every 500 Cancelled exits ≈ every 500ms). 
+ if stats.cancelled_count % 500 == 0 { + let rip = vcpu.get_registers().map(|r| r.rip).unwrap_or(0xDEAD); + diag!( + "AP{} @ {:.1}s: exits={} cancelled={} halt={} cpuid={} mmio={} RIP={:#X}", + ap_id, + stats.start_time.elapsed().as_secs_f64(), + stats.exit_count, + stats.cancelled_count, + stats.total_halt_exits, + stats.cpuid_count, + stats.mmio_count, + rip, + ); + } + } + VcpuExit::MsrAccess { + msr_number, + is_write, + rax, + rdx, + } => { + stats.halt_count = 0; + if is_write { + log::trace!( + "AP{}: MSR write 0x{:08X} <- 0x{:016X}", + ap_id, + msr_number, + (rdx << 32) | (rax & 0xFFFF_FFFF) + ); + if let Err(e) = vcpu.skip_instruction() { + log::error!("AP{} skip_instruction error: {:?}", ap_id, e); + return; + } + } else { + let value = super::handle_msr_read(ap_id, msr_number); + log::trace!( + "AP{}: MSR read 0x{:08X} -> 0x{:X}", + ap_id, + msr_number, + value + ); + if let Err(e) = vcpu.complete_msr_read(value) { + log::error!("AP{} complete_msr_read error: {:?}", ap_id, e); + return; + } + } + } + VcpuExit::CpuidAccess { + rax, + rcx, + default_rax, + default_rbx, + default_rcx, + default_rdx, + } => { + stats.halt_count = 0; + stats.cpuid_count += 1; + let (out_rax, out_rbx, out_rcx, out_rdx) = super::handle_cpuid( + ap_id, + num_vcpus, + rax as u32, + rcx, + default_rax, + default_rbx, + default_rcx, + default_rdx, + ); + log::trace!( + "AP{} CPUID leaf=0x{:X} sub=0x{:X} -> rax=0x{:X}", + ap_id, + rax, + rcx, + out_rax + ); + if let Err(e) = vcpu.complete_cpuid(out_rax, out_rbx, out_rcx, out_rdx) { + log::error!("AP{} complete_cpuid error: {:?}", ap_id, e); + return; + } + } + VcpuExit::UnrecoverableException => { + let regs = vcpu.get_registers().ok(); + let sregs = vcpu.get_special_registers().ok(); + diag!( + "AP{}: TRIPLE FAULT after {} exits, RIP={:#X}, CR0={:#X}, CR3={:#X}, EFER={:#X}", + ap_id, + stats.exit_count, + regs.as_ref().map_or(0, |r| r.rip), + sregs.as_ref().map_or(0, |s| s.cr0), + sregs.as_ref().map_or(0, |s| s.cr3), + 
sregs.as_ref().map_or(0, |s| s.efer),
+                    );
+                    return;
+                }
+                VcpuExit::Unknown(reason) => {
+                    diag!(
+                        "AP{}: unknown exit reason {} after {} exits",
+                        ap_id,
+                        reason,
+                        stats.exit_count
+                    );
+                    return;
+                }
+            }
+
+            if stats.exit_count >= MAX_EXITS {
+                diag!(
+                    "AP{}: EXIT (max_exits) exits={} cancelled={} halt={} cpuid={} mmio={}",
+                    ap_id,
+                    stats.exit_count,
+                    stats.cancelled_count,
+                    stats.total_halt_exits,
+                    stats.cpuid_count,
+                    stats.mmio_count,
+                );
+                return;
+            }
+        }
+    }
+
+    /// Run a VM synchronously on the calling thread (blocking).
+    ///
+    /// Used by `wkrun_start_enter()`. Creates a default `VcpuRunConfig` and
+    /// runs the vCPU loop until the guest shuts down or an error occurs.
+    pub fn run(ctx: VmContext) -> Result {
+        let ctx_id = ctx.id;
+        let run_config = VcpuRunConfig::new();
+        let canceller_slot = Arc::new(Mutex::new(None));
+        let result = run_vcpu_loop(ctx, run_config, canceller_slot);
+        devices::remove_console_buffer(ctx_id);
+        result
+    }
+
+    /// Start a VM on a background thread (non-blocking).
+    ///
+    /// Takes ownership of the context and spawns a thread running the vCPU loop.
+    /// Use `wait()` to block until the VM exits, or `stop()` to request shutdown.
+    ///
+    /// # Errors
+    ///
+    /// Returns `WkrunError::Config` if the running-VM registry lock is poisoned,
+    /// if `ctx_id` is already registered, or if the OS refuses to create the
+    /// thread. In every error case no VM thread is left running.
+    pub fn start(ctx_id: u32, ctx: VmContext) -> Result<()> {
+        // Take the registry lock and reject duplicates BEFORE spawning. Spawning
+        // first would leave a detached VM thread running whenever the duplicate
+        // check (or the lock itself) fails. The lock is held through the insert
+        // so two concurrent start() calls cannot both pass the check.
+        let mut map = RUNNING_VMS
+            .lock()
+            .map_err(|_| WkrunError::Config("running VMs lock poisoned".into()))?;
+        if map.contains_key(&ctx_id) {
+            return Err(WkrunError::Config(format!(
+                "VM {} is already running",
+                ctx_id
+            )));
+        }
+
+        let run_config = VcpuRunConfig::new();
+        // Slot type is inferred from its uses below, the same way `run()` does it.
+        let canceller_slot = Arc::new(Mutex::new(None));
+
+        let rc = run_config.clone();
+        let cs = canceller_slot.clone();
+        // Builder::spawn reports thread-creation failure as an error instead of
+        // panicking, which would poison the registry lock held here.
+        let thread = std::thread::Builder::new()
+            .spawn(move || run_vcpu_loop(ctx, rc, cs))
+            .map_err(|e| WkrunError::Config(format!("failed to spawn VM thread: {}", e)))?;
+
+        map.insert(
+            ctx_id,
+            VmHandle {
+                thread: Some(thread),
+                run_config,
+                canceller: canceller_slot,
+            },
+        );
+        Ok(())
+    }
+
+    /// Block until a running VM exits. Returns the guest exit code.
+    ///
+    /// Removes the VM from the running registry.
After `wait()` returns, + /// the ctx_id is no longer valid. + pub fn wait(ctx_id: u32) -> Result { + let mut map = RUNNING_VMS + .lock() + .map_err(|_| WkrunError::Config("running VMs lock poisoned".into()))?; + let mut handle = map + .remove(&ctx_id) + .ok_or(WkrunError::InvalidContext(ctx_id))?; + drop(map); // Release lock before blocking join. + + let thread = handle + .thread + .take() + .ok_or_else(|| WkrunError::Config("VM thread already joined".into()))?; + let result = thread + .join() + .map_err(|_| WkrunError::Config("VM thread panicked".into()))?; + devices::remove_console_buffer(ctx_id); + result + } + + /// Request a running VM to stop (non-blocking). + /// + /// Sets the stop flag and wakes the vCPU so it exits promptly. + /// The VM thread will exit on its next Halt or Cancelled check. + /// Call `wait()` afterwards to collect the exit code. + pub fn stop(ctx_id: u32) -> Result<()> { + let map = RUNNING_VMS + .lock() + .map_err(|_| WkrunError::Config("running VMs lock poisoned".into()))?; + let handle = map.get(&ctx_id).ok_or(WkrunError::InvalidContext(ctx_id))?; + handle.run_config.request_stop(); + if let Some(ref canceller) = *handle.canceller.lock().unwrap() { + let _ = canceller.cancel(); + } + Ok(()) + } +} + +#[cfg(target_os = "windows")] +pub use imp::{run, start, stop, wait}; + +/// Stub for non-Windows platforms (compile only, never called). +#[cfg(not(target_os = "windows"))] +pub fn run(_ctx: super::context::VmContext) -> super::error::Result { + Err(super::error::WkrunError::Config( + "VM runner is only available on Windows".into(), + )) +} + +/// Stub for non-Windows platforms. +#[cfg(not(target_os = "windows"))] +pub fn start(_ctx_id: u32, _ctx: super::context::VmContext) -> super::error::Result<()> { + Err(super::error::WkrunError::Config( + "VM runner is only available on Windows".into(), + )) +} + +/// Stub for non-Windows platforms. 
/// Handle CPUID exit for any vCPU.
///
/// Injects CPU topology info into leaf 1, synthesises a flat topology for
/// leaves 0xB/0x1F, rewrites the sharing fields of leaf 4, and masks the
/// Hyper-V leaves. This is a pure function (no side effects) for testability.
///
/// * `vcpu_id` – index of the vCPU executing CPUID; reported as its APIC ID.
/// * `num_vcpus` – number of vCPUs presented to the guest.
/// * `leaf` – CPUID leaf (guest EAX input).
/// * `input_rcx` – the guest's original ECX value (sub-leaf number for leaves
///   0xB/0x1F/4). This is distinct from `default_rcx`, which is WHPX's
///   computed default OUTPUT for ECX.
/// * `default_r*` – default results supplied with the exit.
///
/// Returns the `(rax, rbx, rcx, rdx)` values to hand back to the guest.
fn handle_cpuid(
    vcpu_id: u8,
    num_vcpus: u8,
    leaf: u32,
    input_rcx: u64,
    default_rax: u64,
    default_rbx: u64,
    default_rcx: u64,
    default_rdx: u64,
) -> (u64, u64, u64, u64) {
    let apic_id = u64::from(vcpu_id);
    let vcpus = u64::from(num_vcpus);

    match leaf {
        // Leaf 1: feature info + topology.
        //   EBX[23:16] = max number of addressable APIC IDs (num_vcpus)
        //   EBX[31:24] = initial APIC ID (vcpu_id)
        //   ECX bit 31 = "hypervisor present", cleared here
        1 => {
            // Clear bits 31:16 (upper 32 bits and EBX[15:0] are preserved),
            // then set the topology fields.
            let ebx = (default_rbx & 0xFFFF_FFFF_0000_FFFF) | (vcpus << 16) | (apic_id << 24);
            (default_rax, ebx, default_rcx & !(1u64 << 31), default_rdx)
        }
        // Leaf 0xB / 0x1F: Extended Topology Enumeration.
        //
        // The defaults reflect the HOST topology, which confuses the guest
        // kernel when num_vcpus differs from the host: its topology parser
        // loops over sub-leaves until it sees type == 0 (INVALID). We present
        // a flat topology instead: 1 thread per core, num_vcpus cores, no HT.
        0xB | 0x1F => {
            let subleaf = input_rcx & 0xFF; // guest's ECX input = sub-leaf number
            match subleaf {
                // Sub-leaf 0: SMT level — shift 0, one logical processor,
                // ECX = type 1 (SMT) | level 0, EDX = x2APIC ID.
                0 => (0, 1, (1u64 << 8) | subleaf, apic_id),
                // Sub-leaf 1: Core level — num_vcpus logical processors.
                1 => {
                    // shift = ceil(log2(num_vcpus)): bits to shift right to
                    // get the package-level ID from the x2APIC ID.
                    let shift = if num_vcpus <= 1 {
                        0
                    } else {
                        u64::from(vcpus.next_power_of_two().trailing_zeros())
                    };
                    // ECX = type 2 (Core) | level 1.
                    (shift, vcpus, (2u64 << 8) | subleaf, apic_id)
                }
                // Sub-leaf 2+: type 0 (INVALID) terminates the enumeration.
                _ => (0, 0, subleaf, apic_id),
            }
        }
        // Leaf 4: Deterministic Cache Parameters.
        //
        // EAX[25:14] (max threads sharing - 1) and EAX[31:26] (max cores in
        // package - 1) are rewritten to match the guest vCPU count so cache
        // topology is consistent with leaf 0xB.
        4 => {
            let cache_type = default_rax & 0x1F;
            if cache_type == 0 {
                // No more cache levels: pass through unchanged.
                (default_rax, default_rbx, default_rcx, default_rdx)
            } else {
                const SHARING_SHIFT: u32 = 14;
                const SHARING_MASK: u64 = 0xFFF << SHARING_SHIFT;
                const CORES_SHIFT: u32 = 26;
                const CORES_FIELD_MAX: u64 = 0x3F;
                const CORES_MASK: u64 = CORES_FIELD_MAX << CORES_SHIFT;

                // Cache type 3 (unified) is reported as shared by all vCPUs;
                // data/instruction caches are reported as per-core.
                // NOTE(review): a unified L2 also has type 3 — confirm whether
                // the cache level (EAX[7:5]) should be consulted as well.
                let max_sharing = if cache_type == 3 {
                    vcpus.saturating_sub(1)
                } else {
                    0
                };
                // The field is only 6 bits wide: clamp so that more than 64
                // vCPUs cannot overflow into bits 32 and above of RAX.
                let max_cores = vcpus.saturating_sub(1).min(CORES_FIELD_MAX);

                let eax = (default_rax & !(SHARING_MASK | CORES_MASK))
                    | (max_sharing << SHARING_SHIFT)
                    | (max_cores << CORES_SHIFT);
                (eax, default_rbx, default_rcx, default_rdx)
            }
        }
        // Hyper-V CPUID range: return zeros.
        0x4000_0000..=0x4000_00FF => (0, 0, 0, 0),
        _ => (default_rax, default_rbx, default_rcx, default_rdx),
    }
}

/// Handle MSR read for any vCPU.
///
/// Returns the value to inject for the given MSR. IA32_APIC_BASE (0x1B)
/// yields the APIC base address 0xFEE0_0000 with the enable bit (11) set and,
/// for vCPU 0 only, the BSP flag (bit 8). Every other MSR reads as 0.
fn handle_msr_read(vcpu_id: u8, msr_number: u32) -> u64 {
    const IA32_APIC_BASE: u32 = 0x1B;
    const APIC_BASE_ADDR: u64 = 0xFEE0_0000;
    const APIC_ENABLE: u64 = 1 << 11;
    const APIC_BSP: u64 = 1 << 8;

    if msr_number != IA32_APIC_BASE {
        return 0;
    }
    let bsp_flag = if vcpu_id == 0 { APIC_BSP } else { 0 };
    APIC_BASE_ADDR | APIC_ENABLE | bsp_flag
}
+ let ctx = VmContext::default_for_test(); + let result = start(99900, ctx); + + #[cfg(not(target_os = "windows"))] + { + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Windows")); + } + + #[cfg(target_os = "windows")] + { + // start() spawns a thread — the error surfaces in wait(). + // But on Windows, if WHPX isn't available or kernel is missing, + // we still get Ok(()) from start() since the thread handles it. + if result.is_ok() { + let wait_result = wait(99900); + assert!(wait_result.is_err()); + } + } + } + + #[test] + fn test_wait_invalid_id_returns_error() { + let result = wait(99901); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("invalid context") || err.contains("Windows"), + "unexpected error: {}", + err + ); + } + + #[test] + fn test_stop_invalid_id_returns_error() { + let result = stop(99902); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("invalid context") || err.contains("Windows"), + "unexpected error: {}", + err + ); + } + + #[test] + fn test_run_config_used_for_stop() { + // Verify VcpuRunConfig flag propagation (cross-platform). + let config = VcpuRunConfig::new(); + let cloned = config.clone(); + assert!(config.should_run()); + assert!(cloned.should_run()); + + cloned.request_stop(); + assert!(!config.should_run()); + } + + #[test] + fn test_canceller_slot_starts_none() { + // The canceller slot should start as None (cross-platform). + let slot: Arc>> = Arc::new(Mutex::new(None)); + assert!(slot.lock().unwrap().is_none()); + } + + #[cfg(target_os = "windows")] + #[test] + fn test_start_rejects_duplicate_ctx_id() { + // Use a unique ctx_id unlikely to collide with other tests. + let ctx_id = 99903; + let ctx = VmContext::default_for_test(); + // First start might succeed or fail (depending on WHPX availability). 
+ let _ = start(ctx_id, ctx); + + let ctx2 = VmContext::default_for_test(); + let result = start(ctx_id, ctx2); + // If first succeeded, second should fail with "already running". + // Clean up. + let _ = stop(ctx_id); + let _ = wait(ctx_id); + + if result.is_err() { + assert!(result.unwrap_err().to_string().contains("already running")); + } + } + + #[cfg(target_os = "windows")] + #[test] + fn test_double_wait_returns_error() { + let ctx_id = 99904; + let ctx = VmContext::default_for_test(); + if start(ctx_id, ctx).is_ok() { + // First wait should succeed (thread exits with error due to no kernel). + let _ = wait(ctx_id); + // Second wait should fail — already removed. + let result = wait(ctx_id); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("invalid context")); + } + } + + #[cfg(target_os = "windows")] + #[test] + fn test_stop_after_wait_returns_error() { + let ctx_id = 99905; + let ctx = VmContext::default_for_test(); + if start(ctx_id, ctx).is_ok() { + let _ = wait(ctx_id); + // stop() after wait() should fail — already removed from registry. + let result = stop(ctx_id); + assert!(result.is_err()); + } + } + + // --- handle_cpuid tests --- + + #[test] + fn test_cpuid_leaf1_topology_bsp() { + // BSP (vcpu 0) with 2 vCPUs. + // input_rcx=0 (leaf 1 doesn't use sub-leaves). 
+ let (rax, rbx, rcx, rdx) = super::handle_cpuid( + 0, + 2, + 1, + 0, + 0x1234, + 0x0000_0000_0000_5678, + 0x8000_0001, + 0xABCD, + ); + // EBX[23:16] = num_vcpus = 2, EBX[31:24] = vcpu_id = 0 + assert_eq!(rbx & 0x00FF_0000, 0x0002_0000, "EBX[23:16] should be 2"); + assert_eq!( + rbx & 0xFF00_0000, + 0x0000_0000, + "EBX[31:24] should be 0 for BSP" + ); + // EBX[15:0] preserved from default + assert_eq!(rbx & 0xFFFF, 0x5678, "EBX[15:0] should be preserved"); + // ECX bit 31 (hypervisor present) must be cleared + assert_eq!(rcx & (1 << 31), 0, "hypervisor present bit must be cleared"); + // RAX and RDX pass through + assert_eq!(rax, 0x1234); + assert_eq!(rdx, 0xABCD); + } + + #[test] + fn test_cpuid_leaf1_topology_ap() { + // AP (vcpu 3) with 4 vCPUs. + let (_, rbx, _, _) = super::handle_cpuid(3, 4, 1, 0, 0, 0, 0, 0); + assert_eq!((rbx >> 16) & 0xFF, 4, "EBX[23:16] should be num_vcpus=4"); + assert_eq!((rbx >> 24) & 0xFF, 3, "EBX[31:24] should be vcpu_id=3"); + } + + #[test] + fn test_cpuid_hyperv_leaves_zeroed() { + // Hyper-V CPUID range should return all zeros. + for leaf in [0x40000000u32, 0x40000001, 0x400000FF] { + let (rax, rbx, rcx, rdx) = + super::handle_cpuid(0, 1, leaf, 0, 0xDEAD, 0xBEEF, 0xCAFE, 0xF00D); + assert_eq!( + (rax, rbx, rcx, rdx), + (0, 0, 0, 0), + "Hyper-V leaf 0x{:X} must be zeroed", + leaf + ); + } + } + + #[test] + fn test_cpuid_passthrough_other_leaves() { + // Non-special leaves should pass through defaults unchanged. + let (rax, rbx, rcx, rdx) = super::handle_cpuid(0, 2, 0, 0, 0x1111, 0x2222, 0x3333, 0x4444); + assert_eq!((rax, rbx, rcx, rdx), (0x1111, 0x2222, 0x3333, 0x4444)); + + let (rax, rbx, rcx, rdx) = super::handle_cpuid(0, 2, 7, 0, 0xAAAA, 0xBBBB, 0xCCCC, 0xDDDD); + assert_eq!((rax, rbx, rcx, rdx), (0xAAAA, 0xBBBB, 0xCCCC, 0xDDDD)); + } + + // --- CPUID leaf 0xB tests --- + + #[test] + fn test_cpuid_leaf_0xb_smt_level() { + // Sub-leaf 0 = SMT: shift=0, np=1, type=1, edx=vcpu_id. + // input_rcx=0 (sub-leaf 0). 
+ let (rax, rbx, rcx, rdx) = super::handle_cpuid(0, 4, 0xB, 0, 0, 0, 0, 0); + assert_eq!(rax & 0x1F, 0, "SMT shift should be 0 (no HT)"); + assert_eq!(rbx & 0xFFFF, 1, "SMT should report 1 logical proc"); + assert_eq!((rcx >> 8) & 0xFF, 1, "type should be 1 (SMT)"); + assert_eq!(rdx, 0, "x2APIC ID should be vcpu_id=0"); + + // Same for vcpu 3. + let (_, _, _, rdx) = super::handle_cpuid(3, 4, 0xB, 0, 0, 0, 0, 0); + assert_eq!(rdx, 3, "x2APIC ID should be vcpu_id=3"); + } + + #[test] + fn test_cpuid_leaf_0xb_core_level_4vcpus() { + // Sub-leaf 1 = Core: shift=ceil(log2(4))=2, np=4, type=2. + // input_rcx=1 (sub-leaf 1). + let (rax, rbx, rcx, rdx) = super::handle_cpuid(0, 4, 0xB, 1, 0, 0, 0, 0); + assert_eq!(rax & 0x1F, 2, "Core shift should be 2 for 4 vCPUs"); + assert_eq!(rbx & 0xFFFF, 4, "Core should report 4 logical procs"); + assert_eq!((rcx >> 8) & 0xFF, 2, "type should be 2 (Core)"); + assert_eq!(rdx, 0, "x2APIC ID should be vcpu_id=0"); + } + + #[test] + fn test_cpuid_leaf_0xb_core_level_2vcpus() { + let (rax, rbx, _, _) = super::handle_cpuid(1, 2, 0xB, 1, 0, 0, 0, 0); + assert_eq!(rax & 0x1F, 1, "Core shift should be 1 for 2 vCPUs"); + assert_eq!(rbx & 0xFFFF, 2, "Core should report 2 logical procs"); + } + + #[test] + fn test_cpuid_leaf_0xb_core_level_1vcpu() { + let (rax, rbx, _, _) = super::handle_cpuid(0, 1, 0xB, 1, 0, 0, 0, 0); + assert_eq!(rax & 0x1F, 0, "Core shift should be 0 for 1 vCPU"); + assert_eq!(rbx & 0xFFFF, 1, "Core should report 1 logical proc"); + } + + #[test] + fn test_cpuid_leaf_0xb_invalid_subleaf() { + // Sub-leaf 2+ should return type=0 (INVALID) to terminate kernel loop. + let (rax, rbx, rcx, _) = super::handle_cpuid(0, 4, 0xB, 2, 0, 0, 0, 0); + assert_eq!(rax, 0); + assert_eq!(rbx, 0); + assert_eq!((rcx >> 8) & 0xFF, 0, "type should be 0 (INVALID)"); + } + + #[test] + fn test_cpuid_leaf_0x1f_same_as_0xb() { + // Leaf 0x1F should produce identical results to 0xB. 
+ for subleaf in 0..3u64 { + let r_b = super::handle_cpuid(0, 4, 0xB, subleaf, 0, 0, 0, 0); + let r_1f = super::handle_cpuid(0, 4, 0x1F, subleaf, 0, 0, 0, 0); + assert_eq!( + r_b, r_1f, + "Leaf 0xB and 0x1F should match for sub-leaf {}", + subleaf + ); + } + } + + #[test] + fn test_cpuid_leaf4_cache_topology() { + // Leaf 4 with cache_type != 0 should override max_cores and max_threads. + // Simulate L1 data cache (type=1) with host values. + let host_eax: u64 = 1 // cache_type = 1 (data) + | (7 << 14) // max_threads_sharing = 8 (host value) + | (7 << 26); // max_cores = 8 (host value) + let (rax, _, _, _) = super::handle_cpuid(0, 4, 4, 0, host_eax, 0, 0, 0); + // For L1 (non-unified): max_threads_sharing should be 0 (per-core). + assert_eq!((rax >> 14) & 0xFFF, 0, "L1 max_threads_sharing should be 0"); + // max_cores should be num_vcpus - 1 = 3. + assert_eq!((rax >> 26) & 0x3F, 3, "max_cores should be num_vcpus-1=3"); + + // Simulate L3 unified cache (type=3). + let host_eax: u64 = 3 // cache_type = 3 (unified) + | (15 << 14) // max_threads_sharing = 16 (host) + | (7 << 26); // max_cores = 8 (host) + let (rax, _, _, _) = super::handle_cpuid(0, 4, 4, 2, host_eax, 0, 0, 0); + // L3: max_threads_sharing should be num_vcpus - 1 = 3. + assert_eq!((rax >> 14) & 0xFFF, 3, "L3 max_threads_sharing should be 3"); + assert_eq!((rax >> 26) & 0x3F, 3, "max_cores should be num_vcpus-1=3"); + } + + #[test] + fn test_cpuid_leaf4_no_cache_passthrough() { + // cache_type = 0 means no more caches — pass through unchanged. + let (rax, rbx, rcx, rdx) = super::handle_cpuid(0, 4, 4, 0, 0, 0xBEEF, 0xCAFE, 0xDEAD); + assert_eq!((rax, rbx, rcx, rdx), (0, 0xBEEF, 0xCAFE, 0xDEAD)); + } + + // --- handle_msr_read tests --- + + #[test] + fn test_msr_apic_base_bsp() { + // BSP should have enable + BSP flag. 
+ let val = super::handle_msr_read(0, 0x1B); + assert_eq!(val & 0xFFFFF000, 0xFEE0_0000, "APIC base address"); + assert_ne!(val & (1 << 11), 0, "APIC enable bit must be set"); + assert_ne!(val & (1 << 8), 0, "BSP flag must be set for vcpu 0"); + } + + #[test] + fn test_msr_apic_base_ap() { + // AP should have enable but NOT BSP flag. + let val = super::handle_msr_read(1, 0x1B); + assert_eq!(val & 0xFFFFF000, 0xFEE0_0000, "APIC base address"); + assert_ne!(val & (1 << 11), 0, "APIC enable bit must be set"); + assert_eq!(val & (1 << 8), 0, "BSP flag must NOT be set for AP"); + } + + #[test] + fn test_msr_unknown_returns_zero() { + // Unknown MSR should return 0. + assert_eq!(super::handle_msr_read(0, 0x174), 0); + assert_eq!(super::handle_msr_read(1, 0xC000_0080), 0); + } +} diff --git a/src/vmm/src/windows/types.rs b/src/vmm/src/windows/types.rs new file mode 100644 index 000000000..a1005bfd9 --- /dev/null +++ b/src/vmm/src/windows/types.rs @@ -0,0 +1,135 @@ +//! Common types for the Windows WHPX VMM layer. + +/// x86_64 standard registers (general-purpose + instruction pointer + flags). +#[derive(Debug, Default, Clone, Copy)] +#[repr(C)] +pub struct StandardRegisters { + pub rax: u64, + pub rbx: u64, + pub rcx: u64, + pub rdx: u64, + pub rsi: u64, + pub rdi: u64, + pub rsp: u64, + pub rbp: u64, + pub r8: u64, + pub r9: u64, + pub r10: u64, + pub r11: u64, + pub r12: u64, + pub r13: u64, + pub r14: u64, + pub r15: u64, + pub rip: u64, + pub rflags: u64, +} + +/// x86_64 segment register. +#[derive(Debug, Default, Clone, Copy)] +#[repr(C)] +pub struct SegmentRegister { + pub base: u64, + pub limit: u32, + pub selector: u16, + /// Access rights (type + S + DPL + P + AVL + L + D/B + G). + pub access_rights: u16, +} + +/// x86_64 descriptor table register (GDTR, IDTR). +#[derive(Debug, Default, Clone, Copy)] +#[repr(C)] +pub struct DescriptorTable { + pub base: u64, + pub limit: u16, +} + +/// x86_64 special/system registers. 
/// Reason the vCPU exited back to the VMM.
///
/// Every payload is a plain integer or boolean, so the enum derives `Clone`,
/// `PartialEq` and `Eq` (matching the other state enums in this layer), which
/// lets exits be compared directly in tests and dispatch code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VcpuExit {
    /// Guest performed an I/O port read.
    IoIn { port: u16, size: u8 },
    /// Guest performed an I/O port write.
    IoOut { port: u16, size: u8, data: u32 },
    /// Guest performed an MMIO read.
    MmioRead { address: u64, size: u8 },
    /// Guest performed an MMIO write.
    MmioWrite { address: u64, size: u8, data: u64 },
    /// Guest executed HLT instruction.
    Halt,
    /// VM shutdown requested.
    Shutdown,
    /// Hypervisor cancelled the run (e.g., stop requested).
    Cancelled,
    /// Interrupt window available (guest RFLAGS.IF became 1).
    InterruptWindow,
    /// Guest executed RDMSR/WRMSR (requires ExtendedVmExits.X64MsrExit).
    MsrAccess {
        msr_number: u32,
        is_write: bool,
        /// RAX value (contains write data for WRMSR, undefined for RDMSR).
        rax: u64,
        /// RDX value (contains write data for WRMSR, undefined for RDMSR).
        rdx: u64,
    },
    /// Guest executed CPUID (requires ExtendedVmExits.X64CpuidExit).
    CpuidAccess {
        /// Input: EAX (leaf).
        rax: u64,
        /// Input: ECX (sub-leaf).
        rcx: u64,
        /// Default results from host CPUID (pass-through values from WHPX).
        default_rax: u64,
        default_rbx: u64,
        default_rcx: u64,
        default_rdx: u64,
    },
    /// Unrecoverable guest exception (triple fault).
    UnrecoverableException,
    /// Exit reason not handled.
    Unknown(u32),
}
/// Shared state for a vCPU run loop.
///
/// Cloning is cheap and shares the underlying flag: every clone observes a
/// stop requested through any other clone.
#[derive(Clone, Debug)]
pub struct VcpuRunConfig {
    /// Whether the VM should keep running (set to false to request stop).
    pub running: Arc<AtomicBool>,
}

impl Default for VcpuRunConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl VcpuRunConfig {
    /// Create a new vCPU run configuration with the running flag set.
    pub fn new() -> Self {
        VcpuRunConfig {
            running: Arc::new(AtomicBool::new(true)),
        }
    }

    /// Request the vCPU to stop running.
    ///
    /// Stores `false` with `Release` ordering, pairing with the `Acquire`
    /// load in [`should_run`](Self::should_run).
    pub fn request_stop(&self) {
        self.running.store(false, Ordering::Release);
    }

    /// Check if the vCPU should continue running.
    pub fn should_run(&self) -> bool {
        self.running.load(Ordering::Acquire)
    }
}

#[cfg(not(target_os = "windows"))]
compile_error!("WHPX module requires Windows");

#[cfg(target_os = "windows")]
mod imp {
    use std::cell::Cell;
    use std::ptr;

    use windows_sys::Win32::System::Hypervisor::*;

    use super::super::error::{check_hresult, Result};
    use super::super::types::{SpecialRegisters, StandardRegisters, VcpuExit};

    /// Create a zeroed `WHV_REGISTER_VALUE` (used to pre-fill output arrays).
    fn zeroed_reg_value() -> WHV_REGISTER_VALUE {
        // SAFETY: WHV_REGISTER_VALUE is a union of integer/struct types; all-zeros is valid.
        unsafe { std::mem::zeroed() }
    }

    /// Create a `WHV_REGISTER_VALUE` holding `val` in its `Reg64` field.
    fn reg64(val: u64) -> WHV_REGISTER_VALUE {
        WHV_REGISTER_VALUE { Reg64: val }
    }

    /// Extract the `u64` stored in the `Reg64` field of a `WHV_REGISTER_VALUE`.
    ///
    /// # Safety
    ///
    /// Caller must ensure the register contains a 64-bit value, i.e. that
    /// `Reg64` is the union field that was written.
    unsafe fn read_reg64(val: &WHV_REGISTER_VALUE) -> u64 {
        val.Reg64
    }

    /// Bitfield accessors for WHV_X64_IO_PORT_ACCESS_INFO.
    /// The _bitfield layout (from windows-sys):
    ///   bits [0..0] = IsWrite
    ///   bits [1..3] = AccessSize
    ///   bits [4..4] = StringOp
    ///   bits [5..5] = RepPrefix
    ///   bits [6..31] = Reserved
    ///
    /// Returns `true` when bit 0 (IsWrite) is set.
    fn io_access_is_write(info: &WHV_X64_IO_PORT_ACCESS_INFO) -> bool {
        // SAFETY: `_bitfield` is the raw integer view of the union, so any
        // bit pattern is a valid value to read.
        let bits = unsafe { info.Anonymous._bitfield };
        (bits & 1) != 0
    }

    /// Returns the 3-bit AccessSize field (bits 1..=3) of the I/O access info.
    fn io_access_size(info: &WHV_X64_IO_PORT_ACCESS_INFO) -> u8 {
        // SAFETY: `_bitfield` is the raw integer view of the union, so any
        // bit pattern is a valid value to read.
        let bits = unsafe { info.Anonymous._bitfield };
        // Masked to 3 bits, so the narrowing cast cannot truncate.
        ((bits >> 1) & 0x7) as u8
    }

    /// Bitfield accessors for WHV_MEMORY_ACCESS_INFO.
    /// The _bitfield layout:
    ///   bits [0..1] = AccessType (0=read, 1=write, 2=execute)
    ///   bits [2..2] = GpaUnmapped
    ///   bits [3..3] = GvaValid
    ///   bits [4..31] = Reserved
    ///
    /// Returns the 2-bit AccessType field.
    fn mem_access_type(info: &WHV_MEMORY_ACCESS_INFO) -> u32 {
        // SAFETY: `_bitfield` is the raw integer view of the union, so any
        // bit pattern is a valid value to read.
        let bits = unsafe { info.Anonymous._bitfield };
        bits & 0x3
    }

    /// Bitfield constants for WHV_EXTENDED_VM_EXITS.
    /// Bit 0 = X64CpuidExit, Bit 1 = X64MsrExit.
+ const EXTENDED_VM_EXITS_CPUID: u64 = 1 << 0; + const EXTENDED_VM_EXITS_MSR: u64 = 1 << 1; + + /// Bitfield accessor for WHV_X64_MSR_ACCESS_INFO. + /// Bit 0 = IsWrite. + fn msr_access_is_write(info: &WHV_X64_MSR_ACCESS_INFO) -> bool { + let bits = unsafe { info.Anonymous._bitfield }; + (bits & 1) != 0 + } + + /// A WHPX partition (VM container). + /// + /// Wraps `WHV_PARTITION_HANDLE` and manages its lifecycle. + /// When dropped, the partition and all its resources are freed. + pub struct WhpxPartition { + handle: WHV_PARTITION_HANDLE, + } + + // SAFETY: WHPX partition handles can be shared across threads. + // The WHPX API is thread-safe for operations on different objects + // (e.g., different vCPUs within the same partition). + unsafe impl Send for WhpxPartition {} + unsafe impl Sync for WhpxPartition {} + + impl WhpxPartition { + /// Check if WHPX is available on this system. + pub fn is_available() -> Result { + let mut capability = WHV_CAPABILITY { + HypervisorPresent: 0, + }; + let hr = unsafe { + WHvGetCapability( + WHvCapabilityCodeHypervisorPresent, + &mut capability as *mut _ as *mut std::ffi::c_void, + std::mem::size_of::() as u32, + ptr::null_mut(), + ) + }; + check_hresult("WHvGetCapability", hr)?; + + // SAFETY: We requested WHvCapabilityCodeHypervisorPresent, + // so the union field HypervisorPresent is valid. + let present = unsafe { capability.HypervisorPresent }; + Ok(present != 0) + } + + /// Create a new WHPX partition. + pub fn new() -> Result { + // WHV_PARTITION_HANDLE is isize; 0 means invalid. + let mut handle: WHV_PARTITION_HANDLE = 0; + let hr = unsafe { WHvCreatePartition(&mut handle) }; + check_hresult("WHvCreatePartition", hr)?; + + Ok(WhpxPartition { handle }) + } + + /// Set the number of virtual processors for this partition. 
+ pub fn set_processor_count(&self, count: u32) -> Result<()> { + let property = WHV_PARTITION_PROPERTY { + ProcessorCount: count, + }; + let hr = unsafe { + WHvSetPartitionProperty( + self.handle, + WHvPartitionPropertyCodeProcessorCount, + &property as *const _ as *const std::ffi::c_void, + std::mem::size_of::() as u32, + ) + }; + check_hresult("WHvSetPartitionProperty(ProcessorCount)", hr) + } + + /// Enable APIC emulation mode (XApic). + pub fn set_local_apic_emulation(&self, enable: bool) -> Result<()> { + let mode = if enable { + WHvX64LocalApicEmulationModeXApic + } else { + WHvX64LocalApicEmulationModeNone + }; + let property = WHV_PARTITION_PROPERTY { + LocalApicEmulationMode: mode, + }; + let hr = unsafe { + WHvSetPartitionProperty( + self.handle, + WHvPartitionPropertyCodeLocalApicEmulationMode, + &property as *const _ as *const std::ffi::c_void, + std::mem::size_of::() as u32, + ) + }; + check_hresult("WHvSetPartitionProperty(LocalApicEmulationMode)", hr) + } + + /// Enable extended VM exits for MSR and/or CPUID interception. + /// + /// Must be called before [`setup()`]. When enabled, guest RDMSR/WRMSR + /// and CPUID instructions cause VM exits instead of being handled + /// by the hypervisor directly. This is required for Linux kernel boot + /// on WHPX — without it, MSR accesses to unrecognized registers cause + /// #GP faults that cascade into triple faults. 
+ pub fn set_extended_vm_exits(&self, msr_exit: bool, cpuid_exit: bool) -> Result<()> { + let mut bits: u64 = 0; + if cpuid_exit { + bits |= EXTENDED_VM_EXITS_CPUID; + } + if msr_exit { + bits |= EXTENDED_VM_EXITS_MSR; + } + let property = WHV_PARTITION_PROPERTY { + ExtendedVmExits: WHV_EXTENDED_VM_EXITS { AsUINT64: bits }, + }; + let hr = unsafe { + WHvSetPartitionProperty( + self.handle, + WHvPartitionPropertyCodeExtendedVmExits, + &property as *const _ as *const std::ffi::c_void, + std::mem::size_of::() as u32, + ) + }; + check_hresult("WHvSetPartitionProperty(ExtendedVmExits)", hr) + } + + /// Finalize the partition configuration. Must be called before creating + /// virtual processors or mapping memory. + pub fn setup(&self) -> Result<()> { + let hr = unsafe { WHvSetupPartition(self.handle) }; + check_hresult("WHvSetupPartition", hr) + } + + /// Map a host memory region into the guest physical address space. + /// + /// # Safety + /// + /// `host_va` must point to a valid memory region of at least `size` bytes + /// that will remain valid for the lifetime of this mapping. + pub unsafe fn map_gpa_range( + &self, + host_va: *mut u8, + guest_pa: u64, + size: u64, + flags: WHV_MAP_GPA_RANGE_FLAGS, + ) -> Result<()> { + let hr = WHvMapGpaRange( + self.handle, + host_va as *const std::ffi::c_void, + guest_pa, + size, + flags, + ); + check_hresult("WHvMapGpaRange", hr) + } + + /// Unmap a guest physical address range. + pub fn unmap_gpa_range(&self, guest_pa: u64, size: u64) -> Result<()> { + let hr = unsafe { WHvUnmapGpaRange(self.handle, guest_pa, size) }; + check_hresult("WHvUnmapGpaRange", hr) + } + + /// Get the raw partition handle (for creating vCPUs etc). + pub fn handle(&self) -> WHV_PARTITION_HANDLE { + self.handle + } + } + + impl Drop for WhpxPartition { + fn drop(&mut self) { + // WHV_PARTITION_HANDLE is isize; 0 means invalid. + if self.handle != 0 { + // SAFETY: We own this partition handle and it's valid. 
+ unsafe { + WHvDeletePartition(self.handle); + } + } + } + } + + /// A WHPX virtual processor (vCPU). + pub struct WhpxVcpu { + partition_handle: WHV_PARTITION_HANDLE, + index: u32, + // Exit context cache — populated by run(), used by skip_instruction()/complete_io_in(). + exit_rip: Cell, + exit_instruction_len: Cell, + exit_rax: Cell, + // MMIO read completion cache — populated by run() on MMIO read exits. + exit_mmio_gpr_index: Cell>, + exit_mmio_access_size: Cell, + } + + // SAFETY: Each vCPU is operated on by a single thread at a time. + // The WHPX API permits calling WHvRunVirtualProcessor from a dedicated thread. + // Sync is needed because std::thread::scope borrows &WhpxVcpu across threads, + // but each &WhpxVcpu is only accessed by its dedicated vCPU thread — no + // concurrent access to the Cell fields occurs. + unsafe impl Send for WhpxVcpu {} + unsafe impl Sync for WhpxVcpu {} + + impl WhpxVcpu { + /// Create a new virtual processor in the given partition. + pub fn new(partition: &WhpxPartition, index: u32) -> Result { + let hr = unsafe { WHvCreateVirtualProcessor(partition.handle(), index, 0) }; + check_hresult("WHvCreateVirtualProcessor", hr)?; + + Ok(WhpxVcpu { + partition_handle: partition.handle(), + index, + exit_rip: Cell::new(0), + exit_instruction_len: Cell::new(0), + exit_rax: Cell::new(0), + exit_mmio_gpr_index: Cell::new(None), + exit_mmio_access_size: Cell::new(0), + }) + } + + /// Get standard (general-purpose) registers. 
+ pub fn get_registers(&self) -> Result { + let register_names = [ + WHvX64RegisterRax, + WHvX64RegisterRbx, + WHvX64RegisterRcx, + WHvX64RegisterRdx, + WHvX64RegisterRsi, + WHvX64RegisterRdi, + WHvX64RegisterRsp, + WHvX64RegisterRbp, + WHvX64RegisterR8, + WHvX64RegisterR9, + WHvX64RegisterR10, + WHvX64RegisterR11, + WHvX64RegisterR12, + WHvX64RegisterR13, + WHvX64RegisterR14, + WHvX64RegisterR15, + WHvX64RegisterRip, + WHvX64RegisterRflags, + ]; + + // Use heap allocation (Vec) instead of stack arrays — WHPX on some + // Win10 builds crashes with stack-allocated WHV_REGISTER_VALUE arrays + // (likely a 16-byte alignment issue on the stack). + let mut values: Vec = + vec![zeroed_reg_value(); register_names.len()]; + + let hr = unsafe { + WHvGetVirtualProcessorRegisters( + self.partition_handle, + self.index, + register_names.as_ptr(), + register_names.len() as u32, + values.as_mut_ptr(), + ) + }; + check_hresult("WHvGetVirtualProcessorRegisters", hr)?; + + // SAFETY: We requested 64-bit register values, so Reg64 is the valid union field. + unsafe { + Ok(StandardRegisters { + rax: read_reg64(&values[0]), + rbx: read_reg64(&values[1]), + rcx: read_reg64(&values[2]), + rdx: read_reg64(&values[3]), + rsi: read_reg64(&values[4]), + rdi: read_reg64(&values[5]), + rsp: read_reg64(&values[6]), + rbp: read_reg64(&values[7]), + r8: read_reg64(&values[8]), + r9: read_reg64(&values[9]), + r10: read_reg64(&values[10]), + r11: read_reg64(&values[11]), + r12: read_reg64(&values[12]), + r13: read_reg64(&values[13]), + r14: read_reg64(&values[14]), + r15: read_reg64(&values[15]), + rip: read_reg64(&values[16]), + rflags: read_reg64(&values[17]), + }) + } + } + + /// Set standard (general-purpose) registers. 
+        ///
+        /// Writes RAX..R15, RIP and RFLAGS from `regs` in a single WHPX call.
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvSetVirtualProcessorRegisters` fails.
+        pub fn set_registers(&self, regs: &StandardRegisters) -> Result<()> {
+            // Order matters: `values[i]` is written to `register_names[i]`.
+            let register_names = [
+                WHvX64RegisterRax,
+                WHvX64RegisterRbx,
+                WHvX64RegisterRcx,
+                WHvX64RegisterRdx,
+                WHvX64RegisterRsi,
+                WHvX64RegisterRdi,
+                WHvX64RegisterRsp,
+                WHvX64RegisterRbp,
+                WHvX64RegisterR8,
+                WHvX64RegisterR9,
+                WHvX64RegisterR10,
+                WHvX64RegisterR11,
+                WHvX64RegisterR12,
+                WHvX64RegisterR13,
+                WHvX64RegisterR14,
+                WHvX64RegisterR15,
+                WHvX64RegisterRip,
+                WHvX64RegisterRflags,
+            ];
+
+            // Use heap allocation — see get_registers() comment on alignment.
+            let values: Vec<WHV_REGISTER_VALUE> = vec![
+                reg64(regs.rax),
+                reg64(regs.rbx),
+                reg64(regs.rcx),
+                reg64(regs.rdx),
+                reg64(regs.rsi),
+                reg64(regs.rdi),
+                reg64(regs.rsp),
+                reg64(regs.rbp),
+                reg64(regs.r8),
+                reg64(regs.r9),
+                reg64(regs.r10),
+                reg64(regs.r11),
+                reg64(regs.r12),
+                reg64(regs.r13),
+                reg64(regs.r14),
+                reg64(regs.r15),
+                reg64(regs.rip),
+                reg64(regs.rflags),
+            ];
+
+            // SAFETY: both buffers are live and hold `register_names.len()` elements,
+            // the count handed to WHPX.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    register_names.as_ptr(),
+                    register_names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters", hr)
+        }
+
+        /// Get special/system registers (segments, control registers, EFER).
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvGetVirtualProcessorRegisters` fails.
+        pub fn get_special_registers(&self) -> Result<SpecialRegisters> {
+            // Order matters: the indices used when building `SpecialRegisters`
+            // below refer to positions in this array.
+            let register_names = [
+                // Segment registers
+                WHvX64RegisterCs,
+                WHvX64RegisterDs,
+                WHvX64RegisterEs,
+                WHvX64RegisterFs,
+                WHvX64RegisterGs,
+                WHvX64RegisterSs,
+                WHvX64RegisterTr,
+                WHvX64RegisterLdtr,
+                // Descriptor table registers
+                WHvX64RegisterGdtr,
+                WHvX64RegisterIdtr,
+                // Control registers
+                WHvX64RegisterCr0,
+                WHvX64RegisterCr2,
+                WHvX64RegisterCr3,
+                WHvX64RegisterCr4,
+                WHvX64RegisterEfer,
+            ];
+
+            // Use heap allocation — see get_registers() comment on alignment.
+            let mut values: Vec<WHV_REGISTER_VALUE> =
+                vec![zeroed_reg_value(); register_names.len()];
+
+            // SAFETY: both buffers are live and hold `register_names.len()` elements,
+            // the count handed to WHPX.
+            let hr = unsafe {
+                WHvGetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    register_names.as_ptr(),
+                    register_names.len() as u32,
+                    values.as_mut_ptr(),
+                )
+            };
+            check_hresult("WHvGetVirtualProcessorRegisters(special)", hr)?;
+
+            // Helper to extract segment register from WHV_REGISTER_VALUE.
+            // SAFETY: Segment register values are stored in the Segment field of the union.
+            let seg = |v: &WHV_REGISTER_VALUE| {
+                let s = unsafe { &v.Segment };
+                super::super::types::SegmentRegister {
+                    base: s.Base,
+                    limit: s.Limit,
+                    selector: s.Selector,
+                    // WHV_X64_SEGMENT_REGISTER_0 is a union with an Attributes field.
+                    access_rights: unsafe { s.Anonymous.Attributes },
+                }
+            };
+
+            // SAFETY: Table register values are stored in the Table field of the union.
+            let table = |v: &WHV_REGISTER_VALUE| {
+                let t = unsafe { &v.Table };
+                super::super::types::DescriptorTable {
+                    base: t.Base,
+                    limit: t.Limit,
+                }
+            };
+
+            // SAFETY (control registers / EFER): these were requested as 64-bit
+            // registers, so Reg64 is the valid union field.
+            Ok(SpecialRegisters {
+                cs: seg(&values[0]),
+                ds: seg(&values[1]),
+                es: seg(&values[2]),
+                fs: seg(&values[3]),
+                gs: seg(&values[4]),
+                ss: seg(&values[5]),
+                tr: seg(&values[6]),
+                ldt: seg(&values[7]),
+                gdt: table(&values[8]),
+                idt: table(&values[9]),
+                cr0: unsafe { read_reg64(&values[10]) },
+                cr2: unsafe { read_reg64(&values[11]) },
+                cr3: unsafe { read_reg64(&values[12]) },
+                cr4: unsafe { read_reg64(&values[13]) },
+                efer: unsafe { read_reg64(&values[14]) },
+            })
+        }
+
+        /// Set special/system registers.
+        ///
+        /// Writes the segment registers, GDTR/IDTR, CR0/CR2/CR3/CR4 and EFER from
+        /// `sregs` in a single WHPX call.
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvSetVirtualProcessorRegisters` fails.
+        pub fn set_special_registers(&self, sregs: &SpecialRegisters) -> Result<()> {
+            // Order matters: `values[i]` is written to `register_names[i]`.
+            let register_names = [
+                WHvX64RegisterCs,
+                WHvX64RegisterDs,
+                WHvX64RegisterEs,
+                WHvX64RegisterFs,
+                WHvX64RegisterGs,
+                WHvX64RegisterSs,
+                WHvX64RegisterTr,
+                WHvX64RegisterLdtr,
+                WHvX64RegisterGdtr,
+                WHvX64RegisterIdtr,
+                WHvX64RegisterCr0,
+                WHvX64RegisterCr2,
+                WHvX64RegisterCr3,
+                WHvX64RegisterCr4,
+                WHvX64RegisterEfer,
+            ];
+
+            // Helper to build WHV_REGISTER_VALUE for a segment register.
+            let seg_val = |s: &super::super::types::SegmentRegister| WHV_REGISTER_VALUE {
+                Segment: WHV_X64_SEGMENT_REGISTER {
+                    Base: s.base,
+                    Limit: s.limit,
+                    Selector: s.selector,
+                    Anonymous: WHV_X64_SEGMENT_REGISTER_0 {
+                        Attributes: s.access_rights,
+                    },
+                },
+            };
+
+            // Helper to build WHV_REGISTER_VALUE for a table register.
+            let table_val = |t: &super::super::types::DescriptorTable| WHV_REGISTER_VALUE {
+                Table: WHV_X64_TABLE_REGISTER {
+                    Pad: [0u16; 3],
+                    Base: t.base,
+                    Limit: t.limit,
+                },
+            };
+
+            // Use heap allocation — see get_registers() comment on alignment.
+            let values: Vec<WHV_REGISTER_VALUE> = vec![
+                seg_val(&sregs.cs),
+                seg_val(&sregs.ds),
+                seg_val(&sregs.es),
+                seg_val(&sregs.fs),
+                seg_val(&sregs.gs),
+                seg_val(&sregs.ss),
+                seg_val(&sregs.tr),
+                seg_val(&sregs.ldt),
+                table_val(&sregs.gdt),
+                table_val(&sregs.idt),
+                reg64(sregs.cr0),
+                reg64(sregs.cr2),
+                reg64(sregs.cr3),
+                reg64(sregs.cr4),
+                reg64(sregs.efer),
+            ];
+
+            // SAFETY: both buffers are live and hold `register_names.len()` elements,
+            // the count handed to WHPX.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    register_names.as_ptr(),
+                    register_names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(special)", hr)
+        }
+
+        /// Run the virtual processor until a VM exit occurs.
+        ///
+        /// After an I/O exit, call [`skip_instruction`] (for writes) or
+        /// [`complete_io_in`] (for reads) to resume execution.
+        ///
+        /// Every call refreshes the cached exit RIP and instruction length; I/O
+        /// exits also cache RAX, and memory-access exits cache the decoded
+        /// destination register, access size and instruction length.
+        ///
+        /// # Errors
+        /// Returns an error if `WHvRunVirtualProcessor` fails, or — on a
+        /// memory-access exit — if reading the registers or decoding the
+        /// faulting instruction fails.
+        pub fn run(&self) -> Result<VcpuExit> {
+            // Invalidate the MMIO completion cache so that a stale destination
+            // register from an earlier exit can never be consumed by
+            // complete_mmio_read() after an unrelated exit.
+            self.exit_mmio_gpr_index.set(None);
+
+            // Zero-initialised out-parameter that WHPX fills in.
+            let mut exit_context: WHV_RUN_VP_EXIT_CONTEXT = unsafe { std::mem::zeroed() };
+            // SAFETY: the pointer/size pair describes `exit_context`, which outlives the call.
+            let hr = unsafe {
+                WHvRunVirtualProcessor(
+                    self.partition_handle,
+                    self.index,
+                    &mut exit_context as *mut _ as *mut std::ffi::c_void,
+                    std::mem::size_of::<WHV_RUN_VP_EXIT_CONTEXT>() as u32,
+                )
+            };
+            check_hresult("WHvRunVirtualProcessor", hr).map_err(|e| {
+                log::error!(
+                    "WHvRunVirtualProcessor FAILED: {:?} (HRESULT=0x{:08X})",
+                    e,
+                    hr as u32
+                );
+                e
+            })?;
+
+            // Cache RIP from the VP context for skip_instruction/complete_io_in.
+            self.exit_rip.set(exit_context.VpContext.Rip);
+
+            // Extract instruction length from VpContext.
+            // WHV_VP_EXIT_CONTEXT layout: [ExecutionState:2][InstructionLength(4bits)|Cr8(4bits):1]...
+            // InstructionLength is at byte offset 2, lower 4 bits.
+            // SAFETY: VpContext is a repr(C) struct; byte access at offset 2 is within bounds.
+            let vp_instruction_len = unsafe {
+                let vp_bytes = &exit_context.VpContext as *const _ as *const u8;
+                *vp_bytes.add(2) & 0xF
+            };
+            self.exit_instruction_len.set(vp_instruction_len);
+
+            // WHV_RUN_VP_EXIT_REASON is i32; use if/else chain to avoid
+            // warnings about lowercase constant names in match patterns.
+            let reason = exit_context.ExitReason;
+            if reason == WHvRunVpExitReasonX64IoPortAccess {
+                // SAFETY: ExitReason is IoPortAccess, so the IoPortAccess union field is valid.
+                let io = unsafe { &exit_context.Anonymous.IoPortAccess };
+                let port = io.PortNumber;
+                let size = io_access_size(&io.AccessInfo);
+                let is_write = io_access_is_write(&io.AccessInfo);
+
+                self.exit_rax.set(io.Rax);
+
+                if is_write {
+                    // Port writes carry at most 32 bits; the truncation is intentional.
+                    let data = io.Rax as u32;
+                    Ok(VcpuExit::IoOut { port, size, data })
+                } else {
+                    Ok(VcpuExit::IoIn { port, size })
+                }
+            } else if reason == WHvRunVpExitReasonMemoryAccess {
+                // SAFETY: ExitReason is MemoryAccess, so the MemoryAccess union field is valid.
+                let mem_ctx = unsafe { &exit_context.Anonymous.MemoryAccess };
+                let address = mem_ctx.Gpa;
+                let access_type = mem_access_type(&mem_ctx.AccessInfo);
+                let is_write = access_type == 1;
+
+                // Decode the faulting instruction to get access size and write data.
+                // The byte count is clamped to the 16-byte instruction buffer.
+                let byte_count = mem_ctx.InstructionByteCount as usize;
+                let insn_bytes = &mem_ctx.InstructionBytes[..byte_count.min(16)];
+                let regs = self.get_registers().map_err(|e| {
+                    log::error!("MMIO get_registers FAILED at GPA 0x{:x}: {:?}", address, e);
+                    e
+                })?;
+                let insn = match super::super::insn::decode_mmio_insn(insn_bytes, &regs) {
+                    Ok(insn) => insn,
+                    Err(e) => {
+                        // Logged once through the `log` facade; the error is also
+                        // returned to the caller, so no direct stderr write is needed.
+                        log::error!(
+                            "MMIO decode FAILED at GPA 0x{:x}: {:?}, bytes: {:02x?}, is_write={}",
+                            address,
+                            e,
+                            insn_bytes,
+                            is_write
+                        );
+                        return Err(e);
+                    }
+                };
+
+                // The decoded length replaces the VpContext-derived one for this exit.
+                self.exit_instruction_len.set(insn.len);
+                self.exit_mmio_gpr_index.set(insn.gpr_index);
+                self.exit_mmio_access_size.set(insn.access_size);
+
+                if is_write {
+                    Ok(VcpuExit::MmioWrite {
+                        address,
+                        size: insn.access_size,
+                        data: insn.data,
+                    })
+                } else {
+                    Ok(VcpuExit::MmioRead {
+                        address,
+                        size: insn.access_size,
+                    })
+                }
+            } else if reason == WHvRunVpExitReasonX64InterruptWindow {
+                Ok(VcpuExit::InterruptWindow)
+            } else if reason == WHvRunVpExitReasonX64Halt {
+                Ok(VcpuExit::Halt)
+            } else if reason == WHvRunVpExitReasonCanceled {
+                Ok(VcpuExit::Cancelled)
+            } else if reason == WHvRunVpExitReasonX64MsrAccess {
+                // SAFETY: ExitReason is MsrAccess, so the MsrAccess union field is valid.
+                let msr_ctx = unsafe { &exit_context.Anonymous.MsrAccess };
+                let is_write = msr_access_is_write(&msr_ctx.AccessInfo);
+                Ok(VcpuExit::MsrAccess {
+                    msr_number: msr_ctx.MsrNumber,
+                    is_write,
+                    rax: msr_ctx.Rax,
+                    rdx: msr_ctx.Rdx,
+                })
+            } else if reason == WHvRunVpExitReasonX64Cpuid {
+                // SAFETY: ExitReason is CpuidAccess, so the CpuidAccess union field is valid.
+                let cpuid_ctx = unsafe { &exit_context.Anonymous.CpuidAccess };
+                Ok(VcpuExit::CpuidAccess {
+                    rax: cpuid_ctx.Rax,
+                    rcx: cpuid_ctx.Rcx,
+                    default_rax: cpuid_ctx.DefaultResultRax,
+                    default_rbx: cpuid_ctx.DefaultResultRbx,
+                    default_rcx: cpuid_ctx.DefaultResultRcx,
+                    default_rdx: cpuid_ctx.DefaultResultRdx,
+                })
+            } else if reason == WHvRunVpExitReasonUnrecoverableException {
+                Ok(VcpuExit::UnrecoverableException)
+            } else if reason == WHvRunVpExitReasonNone {
+                // An exit with reason "None" is reported to callers as a shutdown.
+                Ok(VcpuExit::Shutdown)
+            } else {
+                Ok(VcpuExit::Unknown(reason as u32))
+            }
+        }
+
+        /// Get cached exit context info (for diagnostics and testing).
+        ///
+        /// Returns `(rip, instruction_len, rax)` from the last VM exit.
+        pub fn exit_info(&self) -> (u64, u8, u64) {
+            (
+                self.exit_rip.get(),
+                self.exit_instruction_len.get(),
+                self.exit_rax.get(),
+            )
+        }
+
+        /// Advance RIP past the last intercepted instruction.
+        ///
+        /// Call after handling [`VcpuExit::IoOut`] or [`VcpuExit::MmioWrite`]
+        /// to resume execution at the next instruction.
+        ///
+        /// # Errors
+        /// Returns an error if reading the registers or writing RIP fails.
+        pub fn skip_instruction(&self) -> Result<()> {
+            let instruction_len = self.exit_instruction_len.get();
+            // Read current RIP from registers (guaranteed correct).
+            let regs = self.get_registers()?;
+            // Wrapping add: a guest-controlled RIP must never panic the host in debug builds.
+            let new_rip = regs.rip.wrapping_add(u64::from(instruction_len));
+            let names = [WHvX64RegisterRip];
+            let values: Vec<WHV_REGISTER_VALUE> = vec![reg64(new_rip)];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(skip)", hr)
+        }
+
+        /// Complete an MSR read (RDMSR): inject result into RAX:RDX and advance RIP.
+        ///
+        /// For RDMSR, the 64-bit result is split: low 32 bits in EAX, high 32 in EDX.
+        /// Call after handling [`VcpuExit::MsrAccess`] where `is_write == false`.
+        ///
+        /// # Errors
+        /// Returns an error if reading or writing the registers fails.
+        pub fn complete_msr_read(&self, value: u64) -> Result<()> {
+            let instruction_len = self.exit_instruction_len.get();
+            let regs = self.get_registers()?;
+            // Wrapping add: a guest-controlled RIP must never panic the host in debug builds.
+            let new_rip = regs.rip.wrapping_add(u64::from(instruction_len));
+            let new_rax = value & 0xFFFF_FFFF;
+            let new_rdx = value >> 32;
+
+            let names = [WHvX64RegisterRip, WHvX64RegisterRax, WHvX64RegisterRdx];
+            let values: Vec<WHV_REGISTER_VALUE> =
+                vec![reg64(new_rip), reg64(new_rax), reg64(new_rdx)];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(msr_read)", hr)
+        }
+
+        /// Complete a CPUID exit: inject results into RAX/RBX/RCX/RDX and advance RIP.
+        ///
+        /// The four values are written to the registers exactly as given.
+        ///
+        /// Call after handling [`VcpuExit::CpuidAccess`].
+        ///
+        /// # Errors
+        /// Returns an error if reading or writing the registers fails.
+        pub fn complete_cpuid(&self, rax: u64, rbx: u64, rcx: u64, rdx: u64) -> Result<()> {
+            let instruction_len = self.exit_instruction_len.get();
+            let regs = self.get_registers()?;
+            let new_rip = regs.rip.wrapping_add(u64::from(instruction_len));
+
+            let names = [
+                WHvX64RegisterRip,
+                WHvX64RegisterRax,
+                WHvX64RegisterRbx,
+                WHvX64RegisterRcx,
+                WHvX64RegisterRdx,
+            ];
+            let values: Vec<WHV_REGISTER_VALUE> = vec![
+                reg64(new_rip),
+                reg64(rax),
+                reg64(rbx),
+                reg64(rcx),
+                reg64(rdx),
+            ];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(cpuid)", hr)
+        }
+
+        /// Complete an I/O IN operation: inject data into RAX and advance RIP.
+        ///
+        /// Preserves upper RAX bits based on the I/O access size:
+        /// - size 1: modifies AL only (bits 0-7)
+        /// - size 2: modifies AX only (bits 0-15)
+        /// - size 4: modifies EAX (bits 0-31)
+        ///
+        /// Any other `size` is treated like size 1.
+        ///
+        /// Call after handling [`VcpuExit::IoIn`].
+        ///
+        /// # Errors
+        /// Returns an error if reading or writing the registers fails.
+        pub fn complete_io_in(&self, data: u32, size: u8) -> Result<()> {
+            let instruction_len = self.exit_instruction_len.get();
+            // Read current registers (RIP and RAX guaranteed correct).
+            let regs = self.get_registers()?;
+            let new_rip = regs.rip.wrapping_add(u64::from(instruction_len));
+            let mask: u64 = match size {
+                1 => 0xFF,
+                2 => 0xFFFF,
+                4 => 0xFFFF_FFFF,
+                _ => 0xFF,
+            };
+            // Keep the RAX bits outside the access width, replace those inside it.
+            let new_rax = (regs.rax & !mask) | (u64::from(data) & mask);
+
+            let names = [WHvX64RegisterRip, WHvX64RegisterRax];
+            let values: Vec<WHV_REGISTER_VALUE> = vec![reg64(new_rip), reg64(new_rax)];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(io_in)", hr)
+        }
+
+        /// Complete an MMIO read: inject data into the destination GPR and advance RIP.
+        ///
+        /// The destination register and access size were cached during [`run()`].
+        /// `data` is masked to the access size (1, 2 or 4 bytes; any other size
+        /// keeps all 64 bits) and the result replaces the *entire* 64-bit
+        /// destination register, i.e. it is always zero-extended.
+        ///
+        /// NOTE(review): a plain 8/16-bit `MOV` on real hardware preserves the
+        /// upper bits of the destination; this implementation does not — confirm
+        /// that the decoder only yields instructions for which that is acceptable.
+        ///
+        /// Call after handling [`VcpuExit::MmioRead`].
+        ///
+        /// # Errors
+        /// Returns an error if no destination register is cached, if the cached
+        /// index is not in `0..=15`, or if reading/writing the registers fails.
+        pub fn complete_mmio_read(&self, data: u64) -> Result<()> {
+            let gpr_index = match self.exit_mmio_gpr_index.get() {
+                Some(idx) => idx,
+                None => {
+                    return Err(super::super::error::WkrunError::Vcpu(
+                        "complete_mmio_read: no cached GPR index".into(),
+                    ))
+                }
+            };
+            let access_size = self.exit_mmio_access_size.get();
+            let insn_len = self.exit_instruction_len.get();
+
+            let mut regs = self.get_registers()?;
+            let new_rip = regs.rip.wrapping_add(u64::from(insn_len));
+
+            // Mask data to access size. For 4-byte writes, x86-64 zero-extends
+            // the 32-bit result into the full 64-bit register.
+            let masked = match access_size {
+                1 => data & 0xFF,
+                2 => data & 0xFFFF,
+                4 => data & 0xFFFF_FFFF,
+                _ => data,
+            };
+
+            // Select the destination GPR (x86 ModRM/REX register numbering).
+            // An out-of-range index is reported instead of silently dropping the
+            // read result while still stepping past the instruction.
+            let dest: &mut u64 = match gpr_index {
+                0 => &mut regs.rax,
+                1 => &mut regs.rcx,
+                2 => &mut regs.rdx,
+                3 => &mut regs.rbx,
+                4 => &mut regs.rsp,
+                5 => &mut regs.rbp,
+                6 => &mut regs.rsi,
+                7 => &mut regs.rdi,
+                8 => &mut regs.r8,
+                9 => &mut regs.r9,
+                10 => &mut regs.r10,
+                11 => &mut regs.r11,
+                12 => &mut regs.r12,
+                13 => &mut regs.r13,
+                14 => &mut regs.r14,
+                15 => &mut regs.r15,
+                other => {
+                    return Err(super::super::error::WkrunError::Vcpu(
+                        format!("complete_mmio_read: invalid GPR index {}", other).into(),
+                    ))
+                }
+            };
+            *dest = masked;
+
+            regs.rip = new_rip;
+            self.set_registers(&regs)
+        }
+
+        /// Inject an external hardware interrupt into the vCPU.
+        ///
+        /// The interrupt is delivered on the next `run()` call. The caller
+        /// must ensure `RFLAGS.IF = 1` before calling this (use
+        /// [`interrupts_enabled`] to check, and [`request_interrupt_window`]
+        /// if interrupts are currently disabled).
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvSetVirtualProcessorRegisters` fails.
+        pub fn inject_interrupt(&self, vector: u8) -> Result<()> {
+            // Build WHV_X64_PENDING_INTERRUPTION_REGISTER as u64:
+            //   Bit 0:      InterruptionPending = 1
+            //   Bits 1-3:   InterruptionType = 0 (external interrupt)
+            //   Bit 4:      DeliverErrorCode = 0
+            //   Bits 16-31: InterruptionVector = vector
+            let pending: u64 = 1 | (u64::from(vector) << 16);
+
+            let names = [WHvRegisterPendingInterruption];
+            let values: Vec<WHV_REGISTER_VALUE> = vec![WHV_REGISTER_VALUE { Reg64: pending }];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(inject_interrupt)", hr)
+        }
+
+        /// Deliver an interrupt via the partition-level WHvRequestInterrupt API.
+        ///
+        /// Unlike [`inject_interrupt`] (which sets WHvRegisterPendingInterruption),
+        /// this API delivers the interrupt at the partition level and — critically —
+        /// resets the vCPU's HLT suspend state on platforms where
+        /// WHvRegisterInternalActivityState is inaccessible (Win10).
+        ///
+        /// Uses Fixed delivery, edge-triggered, physical destination mode.
+        /// Returns Ok(true) if the interrupt was delivered, Ok(false) if the
+        /// API returned an error (caller should fall back to inject_interrupt).
+        /// This function itself never returns `Err`.
+        pub fn request_interrupt(&self, vector: u8) -> Result<bool> {
+            // WHV_INTERRUPT_CONTROL layout (from Hyper-V TLFS / Windows SDK):
+            //   _bitfield (u64):
+            //     bits 0-31:  InterruptType (u32) — 0 = Fixed
+            //     bit 32:     LevelTriggered — 0 = edge
+            //     bit 33:     LogicalDestinationMode — 0 = physical
+            //     bits 34-63: Reserved (0)
+            //   Destination (u32): target vCPU index
+            //   Vector (u32): interrupt vector
+            let interrupt = WHV_INTERRUPT_CONTROL {
+                _bitfield: 0, // Fixed=0, edge-triggered=0, physical=0
+                Destination: self.index,
+                Vector: u32::from(vector),
+            };
+            // SAFETY: `interrupt` lives across the call and the size passed is
+            // the size of that same structure.
+            let hr = unsafe {
+                WHvRequestInterrupt(
+                    self.partition_handle,
+                    &interrupt,
+                    std::mem::size_of::<WHV_INTERRUPT_CONTROL>() as u32,
+                )
+            };
+            if hr == 0 {
+                Ok(true)
+            } else {
+                // Log at warn level (not debug) so it's visible at RUST_LOG=info.
+                // This is a critical diagnostic — if WHvRequestInterrupt fails,
+                // the vCPU may not wake from HLT on Win10.
+                log::warn!(
+                    "WHvRequestInterrupt failed: HRESULT=0x{:08X}, vector={}",
+                    hr as u32,
+                    vector
+                );
+                Ok(false)
+            }
+        }
+
+        /// Check if the guest has interrupts enabled (RFLAGS.IF = 1).
+        ///
+        /// # Errors
+        /// Returns an error if reading the registers fails.
+        pub fn interrupts_enabled(&self) -> Result<bool> {
+            // RFLAGS bit 9 is the interrupt-enable flag (IF).
+            const RFLAGS_IF: u64 = 1 << 9;
+            let regs = self.get_registers()?;
+            Ok(regs.rflags & RFLAGS_IF != 0)
+        }
+
+        /// Check if there is a pending interruption that hasn't been delivered yet.
+        ///
+        /// Returns `true` if `WHvRegisterPendingInterruption` bit 0
+        /// (InterruptionPending) is set. A new injection must NOT be
+        /// attempted while a previous one is still pending — doing so
+        /// would overwrite the old interrupt, leaving its PIC ISR bit
+        /// permanently stuck.
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvGetVirtualProcessorRegisters` fails.
+        pub fn has_pending_interruption(&self) -> Result<bool> {
+            let names = [WHvRegisterPendingInterruption];
+            // Heap-allocated like every other WHV_REGISTER_VALUE buffer in this impl.
+            let mut values: Vec<WHV_REGISTER_VALUE> = vec![zeroed_reg_value(); names.len()];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvGetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_mut_ptr(),
+                )
+            };
+            check_hresult("WHvGetVirtualProcessorRegisters(pending_interruption)", hr)?;
+            // SAFETY: the value is read through the union's 64-bit view, the same
+            // view inject_interrupt() uses to write this register.
+            let pending = unsafe { values[0].Reg64 };
+            Ok(pending & 1 != 0)
+        }
+
+        /// Request an interrupt window exit.
+        ///
+        /// The next `run()` call will exit with [`VcpuExit::InterruptWindow`]
+        /// as soon as the guest enables interrupts (RFLAGS.IF = 1).
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvSetVirtualProcessorRegisters` fails.
+        pub fn request_interrupt_window(&self) -> Result<()> {
+            // WHV_X64_DELIVERABILITY_NOTIFICATIONS_REGISTER:
+            //   Bit 1: InterruptNotification = 1
+            let notifications: u64 = 1 << 1;
+
+            let names = [WHvX64RegisterDeliverabilityNotifications];
+            let values: Vec<WHV_REGISTER_VALUE> = vec![WHV_REGISTER_VALUE {
+                Reg64: notifications,
+            }];
+            // SAFETY: `names` and `values` are live and both hold `names.len()` elements.
+            let hr = unsafe {
+                WHvSetVirtualProcessorRegisters(
+                    self.partition_handle,
+                    self.index,
+                    names.as_ptr(),
+                    names.len() as u32,
+                    values.as_ptr(),
+                )
+            };
+            check_hresult("WHvSetVirtualProcessorRegisters(interrupt_window)", hr)
+        }
+
+        /// Cancel a running vCPU (causes it to exit with Cancelled).
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvCancelRunVirtualProcessor` fails.
+        pub fn cancel(&self) -> Result<()> {
+            // SAFETY: only the stored handle and index are passed; no memory is shared.
+            let hr = unsafe { WHvCancelRunVirtualProcessor(self.partition_handle, self.index, 0) };
+            check_hresult("WHvCancelRunVirtualProcessor", hr)
+        }
+
+        /// Get the vCPU index.
+        pub fn index(&self) -> u32 {
+            self.index
+        }
+
+        /// Create a lightweight canceller that can be sent to another thread.
+        ///
+        /// The canceller copies the raw partition handle and the vCPU index; it
+        /// does not borrow from `self`.
+        pub fn canceller(&self) -> VcpuCanceller {
+            VcpuCanceller {
+                partition_handle: self.partition_handle,
+                index: self.index,
+            }
+        }
+
+        /// Configure AP initial register state after receiving SIPI.
+        ///
+        /// Sets the AP into real mode with CS:IP pointing to the SIPI trampoline:
+        /// - CS.base = sipi_vector * 0x1000, CS.selector = sipi_vector * 0x100
+        /// - IP = 0
+        /// - DL = APIC ID (Linux convention for AP identification)
+        /// - All other regs = 0 / default real mode values
+        ///
+        /// # Errors
+        /// Returns an error if reading the special registers or writing either
+        /// register set fails.
+        pub fn set_ap_initial_regs(&self, sipi_vector: u8, apic_id: u8) -> Result<()> {
+            use super::super::types::{SegmentRegister, StandardRegisters};
+
+            // Real-mode addressing: base = selector << 4. Neither product can
+            // overflow its type for any u8 vector (max 0xFF000 / 0xFF00).
+            let cs_base = u64::from(sipi_vector) * 0x1000;
+            let cs_selector = u16::from(sipi_vector) * 0x100;
+
+            let regs = StandardRegisters {
+                rdx: u64::from(apic_id), // Linux uses DL for APIC ID on AP startup
+                rflags: 0x2,             // x86 requires RFLAGS bit 1 always set
+                ..Default::default()
+            };
+
+            // Read existing special registers to preserve WHPX defaults for TR, LDT,
+            // GDT, IDT. WHPX requires valid access_rights for these even in real mode;
+            // overwriting them with zeros causes WHvRunVpExitReasonInvalidVpRegisterValue
+            // (exit reason 5).
+            let mut sregs = self.get_special_registers()?;
+
+            sregs.cs = SegmentRegister {
+                base: cs_base,
+                limit: 0xFFFF,
+                selector: cs_selector,
+                access_rights: 0x9B, // present, code, readable, accessed
+            };
+            let data_seg = SegmentRegister {
+                base: 0,
+                limit: 0xFFFF,
+                selector: 0,
+                access_rights: 0x93, // present, data, writable, accessed
+            };
+            sregs.ds = data_seg;
+            sregs.es = data_seg;
+            sregs.fs = data_seg;
+            sregs.gs = data_seg;
+            sregs.ss = data_seg;
+            sregs.cr0 = 0x10; // ET (Extension Type) — required for real mode on x86
+
+            self.set_registers(&regs)?;
+            self.set_special_registers(&sregs)?;
+            Ok(())
+        }
+    }
+
+    /// Lightweight handle for cancelling a running vCPU from another thread.
+    ///
+    /// Only supports the cancel operation — safe to use from a timer thread
+    /// to preempt the vCPU for interrupt delivery.
+    #[derive(Clone)]
+    pub struct VcpuCanceller {
+        partition_handle: WHV_PARTITION_HANDLE,
+        index: u32,
+    }
+
+    // SAFETY: WHvCancelRunVirtualProcessor is documented as safe to call
+    // from any thread while the vCPU is running.
+    unsafe impl Send for VcpuCanceller {}
+    unsafe impl Sync for VcpuCanceller {}
+
+    impl VcpuCanceller {
+        /// Cancel the vCPU run, causing it to exit with VcpuExit::Cancelled.
+        ///
+        /// # Errors
+        /// Returns the `check_hresult` error if `WHvCancelRunVirtualProcessor` fails.
+        pub fn cancel(&self) -> Result<()> {
+            // SAFETY: only the stored handle and index are passed; no memory is shared.
+            let hr = unsafe { WHvCancelRunVirtualProcessor(self.partition_handle, self.index, 0) };
+            check_hresult("WHvCancelRunVirtualProcessor", hr)
+        }
+    }
+
+    impl Drop for WhpxVcpu {
+        /// Deletes the virtual processor from its partition.
+        ///
+        /// A failure cannot be propagated from `drop`, so it is logged instead
+        /// of being silently discarded.
+        fn drop(&mut self) {
+            // SAFETY: We own this vCPU and the partition handle is still valid
+            // (guaranteed by the borrow lifetime in practice, but we store a raw handle).
+            let hr = unsafe { WHvDeleteVirtualProcessor(self.partition_handle, self.index) };
+            if let Err(e) = check_hresult("WHvDeleteVirtualProcessor", hr) {
+                log::warn!(
+                    "WHvDeleteVirtualProcessor failed for vCPU {}: {:?} (HRESULT=0x{:08X})",
+                    self.index,
+                    e,
+                    hr as u32
+                );
+            }
+        }
+    }
+
+    #[cfg(test)]
+    mod tests {
+        use super::*;
+
+        // These tests only exercise the plain register data types; they never
+        // touch the WHPX API.
+
+        #[test]
+        fn test_standard_registers_default() {
+            let regs = StandardRegisters::default();
+            assert_eq!(regs.rax, 0);
+            assert_eq!(regs.rip, 0);
+            assert_eq!(regs.rflags, 0);
+        }
+
+        #[test]
+        fn test_special_registers_default() {
+            let sregs = SpecialRegisters::default();
+            assert_eq!(sregs.cr0, 0);
+            assert_eq!(sregs.cr3, 0);
+            assert_eq!(sregs.efer, 0);
+            assert_eq!(sregs.cs.selector, 0);
+        }
+
+        #[test]
+        fn test_segment_register_construction() {
+            let seg = super::super::super::types::SegmentRegister {
+                base: 0,
+                limit: 0xFFFF_FFFF,
+                selector: 0x10,
+                access_rights: 0xC093, // data segment
+            };
+            assert_eq!(seg.selector, 0x10);
+            assert_eq!(seg.access_rights, 0xC093);
+        }
+    }
+}
+
+#[cfg(target_os = "windows")]
+pub use imp::*;