* Add helper functions for;

+ MSFT_NetIPAddress
   + MSFT_NetAdapter
     + Disable, Enable, Rename support
   + Win32_NetworkAdapterConfiguration
 * Add unit tests for all of the above.

PiperOrigin-RevId: 785509372
This commit is contained in:
Matthew Oliver
2025-07-21 11:23:48 -07:00
committed by Copybara-Service
parent 5c26fe5353
commit 0ed520ba4a
7 changed files with 1420 additions and 0 deletions

View File

@@ -0,0 +1,225 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package netw
import (
"fmt"
"github.com/google/deck"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
// IPAddress represents a MSFT_NetIPAddress object.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/nettcpipprov/msft-netipaddress
type IPAddress struct {
AddressFamily uint16
AddressState uint16
InterfaceAlias string
InterfaceIndex uint32
IPAddress string
PreferredLifetime string
PrefixOrigin uint16
SkipAsSource bool
Store uint8
SuffixOrigin uint16
Type uint8
ValidLifetime string
// handle is the internal ole handle
handle *ole.IDispatch
}
// CreateIPAddressOptions represents the options for creating an IP address.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/nettcpipprov/create-msft-netipaddress
type CreateIPAddressOptions struct {
InterfaceIndex uint32
InterfaceAlias string
IPAddress string
AddressFamily uint16
PrefixLength uint8
Type uint8
PrefixOrigin uint8
SuffixOrigin uint8
AddressState uint16
ValidLifetime string // CIM_DATETIME
PreferredLifetime string // CIM_DATETIME
SkipAsSource bool
DefaultGateway string
PolicyStore string
PassThru bool
}
// IPAddressSet contains one or more IPAddresses.
type IPAddressSet struct {
IPAddresses []IPAddress
}
// GetIPAddresses returns a IPAddresses struct.
//
// Get all IP addresses:
//
// svc.GetIPAddresses("")
//
// To get specific IP addresses, provide a valid WMI query filter string, for example:
//
// svc.GetIPAddresses("WHERE IPAddress='192.168.1.1'")
func (svc Service) GetIPAddresses(filter string) (IPAddressSet, error) {
var ipset IPAddressSet
query := "SELECT * FROM MSFT_NetIPAddress"
if filter != "" {
query = fmt.Sprintf("%s %s", query, filter)
}
deck.InfoA(query).With(deck.V(1)).Go()
raw, err := oleutil.CallMethod(svc.wmiSvc, "ExecQuery", query)
if err != nil {
return ipset, fmt.Errorf("ExecQuery(%s): %w", query, err)
}
result := raw.ToIDispatch()
defer result.Release()
countVar, err := oleutil.GetProperty(result, "Count")
if err != nil {
return ipset, fmt.Errorf("oleutil.GetProperty(Count): %w", err)
}
count := int(countVar.Val)
for i := 0; i < count; i++ {
ipresult := IPAddress{}
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
if err != nil {
return ipset, fmt.Errorf("oleutil.CallMethod(ItemIndex, %d): %w", i, err)
}
ipresult.handle = itemRaw.ToIDispatch()
if err := ipresult.Query(); err != nil {
return ipset, fmt.Errorf("ipresult.Query(): %w", err)
}
ipset.IPAddresses = append(ipset.IPAddresses, ipresult)
}
return ipset, nil
}
// Close releases the handle to the IP address.
func (ip *IPAddress) Close() {
if ip.handle != nil {
ip.handle.Release()
}
}
// Query reads and populates the IP address state from WMI.
func (ip *IPAddress) Query() error {
if ip.handle == nil {
return fmt.Errorf("invalid handle")
}
// All the non-string/slice properties
for _, prop := range [][]any{
{"AddressFamily", &ip.AddressFamily},
{"InterfaceIndex", &ip.InterfaceIndex},
{"PrefixOrigin", &ip.PrefixOrigin},
{"SkipAsSource", &ip.SkipAsSource},
{"SuffixOrigin", &ip.SuffixOrigin},
{"Type", &ip.Type},
{"Store", &ip.Store},
{"AddressState", &ip.AddressState},
} {
name, ok := prop[0].(string)
if !ok {
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
}
val, err := oleutil.GetProperty(ip.handle, name)
if err != nil {
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
}
if val.VT != ole.VT_NULL {
if err := AssignVariant(val.Value(), prop[1]); err != nil {
deck.Warningf("AssignVariant(%s): %v", name, err)
}
}
}
// String properties
for _, prop := range [][]any{
{"InterfaceAlias", &ip.InterfaceAlias},
{"IPAddress", &ip.IPAddress},
{"ValidLifetime", &ip.ValidLifetime},
{"PreferredLifetime", &ip.PreferredLifetime},
} {
name, ok := prop[0].(string)
if !ok {
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
}
val, err := oleutil.GetProperty(ip.handle, name)
if err != nil {
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
}
if val.VT != ole.VT_NULL {
*(prop[1].(*string)) = val.ToString()
}
}
return nil
}
// IPOutput represents the output of the Create method.
type IPOutput struct{}
// Create creates the IP address on the current instance.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/nettcpipprov/create-msft-netipaddress
func (ip *IPAddress) Create(opts CreateIPAddressOptions) (IPOutput, error) {
ipset := IPOutput{}
return ipset, fmt.Errorf("not implemented")
// var createdObject ole.VARIANT
// ole.VariantInit(&createdObject)
// Parameters must be passed in the order defined by the WMI method signature.
// res, err := oleutil.CallMethod(ip.handle, "Create",
// opts.InterfaceIndex,
// opts.InterfaceAlias,
// opts.IPAddress,
// opts.AddressFamily,
// opts.PrefixLength,
// opts.Type,
// opts.PrefixOrigin,
// opts.SuffixOrigin,
// opts.AddressState,
// opts.ValidLifetime,
// opts.PreferredLifetime,
// opts.SkipAsSource,
// opts.DefaultGateway,
// opts.PolicyStore,
// opts.PassThru,
// &createdObject, // output
// )
// if err != nil {
// return ipset, fmt.Errorf("Create: %w", err)
// }
// if val, ok := res.Value().(int32); val != 0 || !ok {
// return ipset, fmt.Errorf("error code returned during create: %d", val)
// }
// ip.handle = createdObject.ToIDispatch()
// return ipset, nil
}

View File

@@ -0,0 +1,84 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package netw
import (
"fmt"
"testing"
)
func TestGetIPAddresses(t *testing.T) {
n, err := Connect()
if err != nil {
t.Fatalf("failed to connect to network: %v", err)
}
defer n.Close()
// First get all IPs to find a valid address for a filter test.
allIPs, err := n.GetIPAddresses("")
if err != nil {
t.Fatalf("Initial GetIPAddresses() failed: %v", err)
}
if len(allIPs.IPAddresses) == 0 {
t.Skip("No IP addresses found, skipping test.")
}
ipToFilter := allIPs.IPAddresses[0]
testCases := []struct {
name string
filter string
wantErr bool
}{
{
name: "no filter",
filter: "",
},
{
name: "filter by ip address",
filter: fmt.Sprintf("WHERE IPAddress = '%s'", ipToFilter.IPAddress),
},
{
name: "bad filter",
filter: "WHERE BadFilter = 'true'",
wantErr: true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ips, err := n.GetIPAddresses(tt.filter)
if (err != nil) != tt.wantErr {
t.Errorf("GetIPAddresses() error = %v, wantErr = %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Expected error, nothing more to check.
}
if len(ips.IPAddresses) == 0 {
t.Fatal("got 0 ip addresses, want at least 1")
}
for _, ip := range ips.IPAddresses {
if ip.IPAddress == "" {
t.Error("ip has an empty IPAddress")
}
if ip.InterfaceIndex == 0 {
t.Errorf("ip %q has a zero InterfaceIndex", ip.IPAddress)
}
}
})
}
}

View File

@@ -0,0 +1,485 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package netw
import (
"fmt"
"github.com/google/deck"
"golang.org/x/sys/windows"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
// NetAdapter represents a MSFT_NetAdapter object. It's important to note that some fields are read-only.
//
// Ref: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
type NetAdapter struct {
// The name of the network adapter. This property is inherited from CIM_ManagedSystemElement.
Name string
// Current status of the object. This property is inherited from CIM_ManagedSystemElement.
Status string
// The availability and status of the device. This property is inherited from CIM_LogicalDevice.
Availability uint16
// If TRUE, the device is using a user-defined configuration. This property is inherited from CIM_LogicalDevice.
ConfigManagerUserConfig bool
// The name of the class or subclass used in the creation of an instance. This property is inherited from CIM_LogicalDevice.
CreationClassName string
// The address or other identifying information for the network adapter. This property is inherited from CIM_LogicalDevice.
DeviceID string
// If TRUE, the error reported in the LastErrorCode property is now cleared. This property is inherited from CIM_LogicalDevice.
ErrorCleared bool
// A string that provides more information about the error recorded in the LastErrorCode property and information about any corrective actions that can be taken. This property is inherited from CIM_LogicalDevice.
ErrorDescription string
// The last error code reported by the logical device. This property is inherited from CIM_LogicalDevice.
LastErrorCode uint32
// The Plug and Play device identifier of the logical device. This property is inherited from CIM_LogicalDevice.
PNPDeviceID string
// An array of the specific power-related capabilities of a logical device. This property is inherited from CIM_LogicalDevice.
PowerManagementCapabilities []uint16
// If TRUE, the device can be power managed. This property is inherited from CIM_LogicalDevice.
PowerManagementSupported bool
// The status of the logical device. This property is inherited from CIM_LogicalDevice.
StatusInfo uint16
// The value of the CreationClassName property for the scoping system. This property is inherited from CIM_LogicalDevice.
SystemCreationClassName string
// The name of the scoping system. This property is inherited from CIM_LogicalDevice.
SystemName string
// The current bandwidth of the port in bits per second.
Speed uint64
// The maximum bandwidth of the port in bits per second.
MaxSpeed uint64
// The requested bandwidth of the port in bits per second.
RequestedSpeed uint64
// In cases where a port can be used for more than one function, this property indicates its primary usage.
UsageRestriction uint16
// The specific type of the port.
PortType uint16
// A string that describes the port type when the PortType property is set to 1 (Other).
OtherPortType string
// The network port type when the PortType property is set to 1 (Other).
OtherNetworkPortType string
// The port number.
PortNumber uint16
// The link technology for the port.
LinkTechnology uint16
// A string that describes the link technology when the LinkTechnology property is set to 1 (Other).
OtherLinkTechnology string
// The network address that is hardcoded into a port.
PermanentAddress string
// An array of network addresses for the port.
NetworkAddresses []string
// If TRUE, the port is operating in full duplex mode.
FullDuplex bool
// If TRUE, the port can automatically determine the speed or other communications characteristics of the attached network media.
AutoSense bool
// The maximum transmission unit (MTU) that can be supported.
SupportedMaximumTransmissionUnit uint64
// The active or negotiated maximum transmission unit (MTU) on the port.
ActiveMaximumTransmissionUnit uint64
// The description of the network interface.
InterfaceDescription string
// The name of the network interface.
InterfaceName string
// The network layer unique identifier (LUID) for the network interface.
NetLuid uint64
// The GUID for the network interface.
InterfaceGUID windows.GUID
// The index for the network interface.
InterfaceIndex uint32
// The name of the device object for the network adapter.
DeviceName string
// The index of the LUID for the network adapter.
NetLuidIndex uint32
// If TRUE, this is a virtual network adapter.
Virtual bool
// If TRUE, this network adapter is not displayed in the user interface.
Hidden bool
// If TRUE, this network adapter cannot be removed by a user.
NotUserRemovable bool
// If TRUE, this is an intermediate driver filter.
IMFilter bool
// The interface type as defined by the Internet Assigned Names Authority (IANA).
InterfaceType uint32
// If TRUE, this is a hardware interface.
HardwareInterface bool
// If TRUE, this is a WDM interface.
WdmInterface bool
// If TRUE, this is an endpoint interface.
EndPointInterface bool
// If TRUE, this is an iSCSI interface.
ISCSIInterface bool
// The current state of the network adapter.
State uint32
// The media type that the network adapter supports.
NdisMedium uint32
// The physical media type of the network adapter.
NdisPhysicalMedium uint32
// The operational status of the network interface.
InterfaceOperationalStatus uint32
// If TRUE, the operational status is down because the default port is not authenticated.
OperationalStatusDownDefaultPortNotAuthenticated bool
// If TRUE, the operational status is down because the media is disconnected.
OperationalStatusDownMediaDisconnected bool
// If TRUE, the operational status is down because the interface is paused.
OperationalStatusDownInterfacePaused bool
// If TRUE, the operational status is down because the interface is in a low power state.
OperationalStatusDownLowPowerState bool
// The administrative status of the network interface.
InterfaceAdminStatus uint32
// The media connect state of the network adapter.
MediaConnectState uint32
// The maximum transmission unit (MTU) size for the network adapter.
MtuSize uint32
// The VLAN identifier for the network adapter.
VlanID uint16
// The transmit link speed for the network adapter.
TransmitLinkSpeed uint64
// The receive link speed for the network adapter.
ReceiveLinkSpeed uint64
// If TRUE, the network adapter is in promiscuous mode.
PromiscuousMode bool
// If TRUE, the device is enabled to wake the system.
DeviceWakeUpEnable bool
// If TRUE, a connector is present on the network adapter.
ConnectorPresent bool
// The duplex state of the media.
MediaDuplexState uint32
// The date of the driver for the network adapter.
DriverDate string
// The date of the driver for the network adapter, in 100-nanosecond intervals.
DriverDateData uint64
// The version of the driver for the network adapter.
DriverVersionString string
// The name of the driver for the network adapter.
DriverName string
// The description of the driver for the network adapter.
DriverDescription string
// The major version of the driver for the network adapter.
MajorDriverVersion uint16
// The minor version of the driver for the network adapter.
MinorDriverVersion uint16
// The major NDIS version of the driver for the network adapter.
DriverMajorNdisVersion uint8
// The minor NDIS version of the driver for the network adapter.
DriverMinorNdisVersion uint8
// The provider of the driver for the network adapter.
DriverProvider string
// The component identifier for the network adapter.
ComponentID string
// The indices of the lower layer interfaces.
LowerLayerInterfaceIndices []uint32
// The indices of the higher layer interfaces.
HigherLayerInterfaceIndices []uint32
// If TRUE, the network adapter is administratively locked.
AdminLocked bool
handle *ole.IDispatch
}
// A NetAdapterSet contains one or more NetAdapters.
type NetAdapterSet struct {
NetAdapters []NetAdapter
}
// Close releases the handle to the network adapter.
func (n *NetAdapter) Close() {
if n.handle != nil {
n.handle.Release()
}
}
// GetNetAdapters queries for local network adapters.
//
// Close() must be called on the resulting NetAdapter to ensure all network adapters are released.
//
// Get all network adapters:
//
// svc.GetNetAdapters("")
//
// To get specific network adapters, provide a valid WMI query filter string, for example:
//
// svc.GetNetAdapters("WHERE Name='Wi-Fi'")
func (svc Service) GetNetAdapters(filter string) (NetAdapterSet, error) {
var netAdapters NetAdapterSet
query := "SELECT * FROM MSFT_NetAdapter"
if filter != "" {
query = fmt.Sprintf("%s %s", query, filter)
}
deck.InfoA(query).With(deck.V(1)).Go()
raw, err := oleutil.CallMethod(svc.wmiSvc, "ExecQuery", query)
if err != nil {
return netAdapters, fmt.Errorf("ExecQuery(%s): %w", query, err)
}
result := raw.ToIDispatch()
defer result.Release()
countVar, err := oleutil.GetProperty(result, "Count")
if err != nil {
return netAdapters, fmt.Errorf("oleutil.GetProperty(Count): %w", err)
}
count := int(countVar.Val)
for i := 0; i < count; i++ {
netAdapter := NetAdapter{}
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
if err != nil {
return netAdapters, fmt.Errorf("oleutil.CallMethod(ItemIndex, %d): %w", i, err)
}
netAdapter.handle = itemRaw.ToIDispatch()
if err := netAdapter.Query(); err != nil {
return netAdapters, fmt.Errorf("netAdapter.Query(): %w", err)
}
netAdapters.NetAdapters = append(netAdapters.NetAdapters, netAdapter)
}
return netAdapters, nil
}
// Disable disables the network adapter.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/netadaptercimprov/disable-msft-netadapter
func (n *NetAdapter) Disable() error {
if n.handle == nil {
return fmt.Errorf("invalid handle")
}
res, err := oleutil.CallMethod(n.handle, "Disable")
if err != nil {
return fmt.Errorf("Disable: %w", err)
}
if val, ok := res.Value().(int32); val != 0 || !ok {
return fmt.Errorf("error code returned during Disable: %d", val)
}
return nil
}
// Enable enables the network adapter.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/netadaptercimprov/enable-msft-netadapter
func (n *NetAdapter) Enable() error {
if n.handle == nil {
return fmt.Errorf("invalid handle")
}
res, err := oleutil.CallMethod(n.handle, "Enable")
if err != nil {
return fmt.Errorf("Enable: %w", err)
}
if val, ok := res.Value().(int32); val != 0 || !ok {
return fmt.Errorf("error code returned during Enable: %d", val)
}
return nil
}
// Rename renames the network adapter.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/netadaptercimprov/msft-netadapter-rename
func (n *NetAdapter) Rename(name string) error {
if n.handle == nil {
return fmt.Errorf("invalid handle")
}
res, err := oleutil.CallMethod(n.handle, "Rename", name)
if err != nil {
return fmt.Errorf("Rename: %w", err)
}
if val, ok := res.Value().(int32); val != 0 || !ok {
return fmt.Errorf("error code returned during Rename: %d", val)
}
return nil
}
// Query reads and populates the network adapter state from WMI.
func (n *NetAdapter) Query() error {
if n.handle == nil {
return fmt.Errorf("invalid handle")
}
// All the non-string/slice properties
for _, prop := range [][]any{
{"Availability", &n.Availability},
{"ErrorCleared", &n.ErrorCleared},
{"LastErrorCode", &n.LastErrorCode},
{"PowerManagementSupported", &n.PowerManagementSupported},
{"StatusInfo", &n.StatusInfo},
{"Speed", &n.Speed},
{"MaxSpeed", &n.MaxSpeed},
{"RequestedSpeed", &n.RequestedSpeed},
{"UsageRestriction", &n.UsageRestriction},
{"PortType", &n.PortType},
{"PortNumber", &n.PortNumber},
{"LinkTechnology", &n.LinkTechnology},
{"FullDuplex", &n.FullDuplex},
{"AutoSense", &n.AutoSense},
{"SupportedMaximumTransmissionUnit", &n.SupportedMaximumTransmissionUnit},
{"ActiveMaximumTransmissionUnit", &n.ActiveMaximumTransmissionUnit},
{"NetLuid", &n.NetLuid},
{"InterfaceIndex", &n.InterfaceIndex},
{"NetLuidIndex", &n.NetLuidIndex},
{"Virtual", &n.Virtual},
{"Hidden", &n.Hidden},
{"NotUserRemovable", &n.NotUserRemovable},
{"IMFilter", &n.IMFilter},
{"InterfaceType", &n.InterfaceType},
{"HardwareInterface", &n.HardwareInterface},
{"WdmInterface", &n.WdmInterface},
{"EndPointInterface", &n.EndPointInterface},
{"ISCSIInterface", &n.ISCSIInterface},
{"State", &n.State},
{"NdisMedium", &n.NdisMedium},
{"NdisPhysicalMedium", &n.NdisPhysicalMedium},
{"InterfaceOperationalStatus", &n.InterfaceOperationalStatus},
{"OperationalStatusDownDefaultPortNotAuthenticated", &n.OperationalStatusDownDefaultPortNotAuthenticated},
{"OperationalStatusDownMediaDisconnected", &n.OperationalStatusDownMediaDisconnected},
{"OperationalStatusDownInterfacePaused", &n.OperationalStatusDownInterfacePaused},
{"OperationalStatusDownLowPowerState", &n.OperationalStatusDownLowPowerState},
{"InterfaceAdminStatus", &n.InterfaceAdminStatus},
{"MediaConnectState", &n.MediaConnectState},
{"MtuSize", &n.MtuSize},
{"VlanID", &n.VlanID},
{"TransmitLinkSpeed", &n.TransmitLinkSpeed},
{"ReceiveLinkSpeed", &n.ReceiveLinkSpeed},
{"PromiscuousMode", &n.PromiscuousMode},
{"DeviceWakeUpEnable", &n.DeviceWakeUpEnable},
{"ConnectorPresent", &n.ConnectorPresent},
{"MediaDuplexState", &n.MediaDuplexState},
{"DriverDateData", &n.DriverDateData},
{"MajorDriverVersion", &n.MajorDriverVersion},
{"MinorDriverVersion", &n.MinorDriverVersion},
{"DriverMajorNdisVersion", &n.DriverMajorNdisVersion},
{"DriverMinorNdisVersion", &n.DriverMinorNdisVersion},
{"AdminLocked", &n.AdminLocked},
} {
val, err := oleutil.GetProperty(n.handle, prop[0].(string))
if err != nil {
return fmt.Errorf("oleutil.GetProperty(%s): %w", prop[0].(string), err)
}
if val.VT != ole.VT_NULL {
if err := AssignVariant(val.Value(), prop[1]); err != nil {
deck.Warningf("AssignVariant(%s): %v", prop[0].(string), err)
}
}
}
// String properties
for _, prop := range [][]any{
{"Name", &n.Name},
{"Status", &n.Status},
{"CreationClassName", &n.CreationClassName},
{"DeviceID", &n.DeviceID},
{"ErrorDescription", &n.ErrorDescription},
{"PNPDeviceID", &n.PNPDeviceID},
{"SystemCreationClassName", &n.SystemCreationClassName},
{"SystemName", &n.SystemName},
{"OtherPortType", &n.OtherPortType},
{"OtherNetworkPortType", &n.OtherNetworkPortType},
{"OtherLinkTechnology", &n.OtherLinkTechnology},
{"PermanentAddress", &n.PermanentAddress},
{"InterfaceDescription", &n.InterfaceDescription},
{"InterfaceName", &n.InterfaceName},
{"DeviceName", &n.DeviceName},
{"DriverDate", &n.DriverDate},
{"DriverVersionString", &n.DriverVersionString},
{"DriverName", &n.DriverName},
{"DriverDescription", &n.DriverDescription},
{"DriverProvider", &n.DriverProvider},
{"ComponentID", &n.ComponentID},
} {
val, err := oleutil.GetProperty(n.handle, prop[0].(string))
if err != nil {
return fmt.Errorf("oleutil.GetProperty(%s): %w", prop[0].(string), err)
}
if val.VT != ole.VT_NULL {
*(prop[1].(*string)) = val.ToString()
}
}
// GUID
prop, err := oleutil.GetProperty(n.handle, "InterfaceGUID")
if err != nil {
return fmt.Errorf("oleutil.GetProperty(InterfaceGUID): %w", err)
}
if prop.VT != ole.VT_NULL {
guid, err := windows.GUIDFromString(prop.ToString())
if err != nil {
return fmt.Errorf("GUIDFromString(%s): %w", prop.ToString(), err)
}
n.InterfaceGUID = guid
}
// Slice properties
// NetworkAddresses
prop, err = oleutil.GetProperty(n.handle, "NetworkAddresses")
if err != nil {
return fmt.Errorf("oleutil.GetProperty(NetworkAddresses): %w", err)
}
if prop.VT != ole.VT_NULL {
for _, v := range prop.ToArray().ToValueArray() {
s, ok := v.(string)
if !ok {
return fmt.Errorf("error converting NetworkAddress to string")
}
n.NetworkAddresses = append(n.NetworkAddresses, s)
}
}
// PowerManagementCapabilities
prop, err = oleutil.GetProperty(n.handle, "PowerManagementCapabilities")
if err != nil {
return fmt.Errorf("oleutil.GetProperty(PowerManagementCapabilities): %w", err)
}
if prop.VT != ole.VT_NULL {
for _, v := range prop.ToArray().ToValueArray() {
val, ok := v.(int32)
if !ok {
return fmt.Errorf("error converting PowerManagementCapabilities to uint16, got %T", v)
}
n.PowerManagementCapabilities = append(n.PowerManagementCapabilities, uint16(val))
}
}
// LowerLayerInterfaceIndices
prop, err = oleutil.GetProperty(n.handle, "LowerLayerInterfaceIndices")
if err != nil {
return fmt.Errorf("oleutil.GetProperty(LowerLayerInterfaceIndices): %w", err)
}
if prop.VT != ole.VT_NULL {
for _, v := range prop.ToArray().ToValueArray() {
val, ok := v.(int32)
if !ok {
return fmt.Errorf("error converting LowerLayerInterfaceIndices to uint32, got %T", v)
}
n.LowerLayerInterfaceIndices = append(n.LowerLayerInterfaceIndices, uint32(val))
}
}
// HigherLayerInterfaceIndices
prop, err = oleutil.GetProperty(n.handle, "HigherLayerInterfaceIndices")
if err != nil {
return fmt.Errorf("oleutil.GetProperty(HigherLayerInterfaceIndices): %w", err)
}
if prop.VT != ole.VT_NULL {
for _, v := range prop.ToArray().ToValueArray() {
val, ok := v.(int32)
if !ok {
return fmt.Errorf("error converting HigherLayerInterfaceIndices to uint32, got %T", v)
}
n.HigherLayerInterfaceIndices = append(n.HigherLayerInterfaceIndices, uint32(val))
}
}
return nil
}

View File

@@ -0,0 +1,203 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package netw
import (
"fmt"
"testing"
"golang.org/x/sys/windows"
)
func TestGetNetAdapters(t *testing.T) {
n, err := Connect()
if err != nil {
t.Fatalf("failed to connect to network: %v", err)
}
defer n.Close()
// First get all adapters to find a valid name for a filter test.
allAdapters, err := n.GetNetAdapters("")
if err != nil {
t.Fatalf("Initial GetNetAdapters() failed: %v", err)
}
if len(allAdapters.NetAdapters) == 0 {
t.Skip("No network adapters found, skipping test.")
}
adapterToFilter := allAdapters.NetAdapters[0]
testCases := []struct {
name string
filter string
wantErr bool
}{
{
name: "no filter",
filter: "",
},
{
name: "filter by name",
filter: fmt.Sprintf("WHERE Name = '%s'", adapterToFilter.Name),
},
{
name: "bad filter",
filter: fmt.Sprintf("WHERE OS = 'MacOS'"),
wantErr: true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
adapters, err := n.GetNetAdapters(tt.filter)
if (err != nil) != tt.wantErr {
t.Errorf("GetNetAdapters() error = %v, wantErr = %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Expected error, nothing more to check.
}
if len(adapters.NetAdapters) == 0 {
t.Fatal("got 0 adapters, want at least 1")
}
for _, adapter := range adapters.NetAdapters {
if adapter.Name == "" {
t.Error("adapter has an empty Name")
}
if adapter.InterfaceDescription == "" {
t.Errorf("adapter %q has an empty InterfaceDescription", adapter.Name)
}
if adapter.InterfaceGUID == (windows.GUID{}) {
t.Errorf("adapter %q has a zero InterfaceGUID", adapter.Name)
}
}
})
}
}
func TestRenameAdapter(t *testing.T) {
n, err := Connect()
if err != nil {
t.Fatalf("failed to connect to network: %v", err)
}
defer n.Close()
adapters, err := n.GetNetAdapters("")
if err != nil {
t.Fatalf("GetNetAdapters() failed: %v", err)
}
if len(adapters.NetAdapters) == 0 {
t.Fatal("GetNetAdapters() returned no adapters, cannot test rename.")
}
adapterToTest := &adapters.NetAdapters[0]
testCases := []struct {
name string
testName string
wantErr bool
}{
{
name: "simple rename",
testName: "cider-test-adapter",
},
{
name: "empty name",
testName: "",
wantErr: true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
// Get the current state of the adapter to have the correct original name for cleanup.
currentAdapterState, err := findAdapterByGUID(t, n, adapterToTest.InterfaceGUID)
if err != nil {
t.Fatalf("Could not find adapter by GUID before renaming: %v", err)
}
originalName := currentAdapterState.Name
// Defer cleanup to restore the original name at the end of the sub-test.
defer func() {
adapterForCleanup, findErr := findAdapterByGUID(t, n, adapterToTest.InterfaceGUID)
if findErr != nil {
t.Logf("Could not find adapter for cleanup: %v", findErr)
return
}
if err := adapterForCleanup.Rename(originalName); err != nil {
t.Errorf("Failed to rename adapter back to %q: %v", originalName, err)
}
}()
err = currentAdapterState.Rename(tt.testName)
if (err != nil) != tt.wantErr {
t.Fatalf("Rename() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return // Expected error, nothing more to check.
}
// Verify the rename worked.
renamedAdapter, err := findAdapterByGUID(t, n, adapterToTest.InterfaceGUID)
if err != nil {
t.Fatalf("Could not find adapter by GUID after renaming: %v", err)
}
if renamedAdapter.Name != tt.testName {
t.Errorf("adapter name is %q, want %q", renamedAdapter.Name, tt.testName)
}
})
}
}
// findAdapterByGUID is a helper to find a specific network adapter.
func findAdapterByGUID(t *testing.T, n Service, guid windows.GUID) (*NetAdapter, error) {
t.Helper()
adapters, err := n.GetNetAdapters("")
if err != nil {
return nil, fmt.Errorf("GetNetAdapters() failed: %w", err)
}
for i := range adapters.NetAdapters {
if adapters.NetAdapters[i].InterfaceGUID == guid {
return &adapters.NetAdapters[i], nil
}
}
return nil, fmt.Errorf("adapter with GUID %s not found", guid)
}
func TestEnableAlreadyEnabledAdapter(t *testing.T) {
n, err := Connect()
if err != nil {
t.Fatalf("failed to connect to network: %v", err)
}
defer n.Close()
// Filter for enabled adapters using InterfaceAdminStatus = 1 (Up).
adapters, err := n.GetNetAdapters("WHERE InterfaceAdminStatus = 1")
if err != nil {
t.Fatalf("GetNetAdapters() failed: %v", err)
}
if len(adapters.NetAdapters) == 0 {
t.Skip("No enabled network adapters found, skipping test.")
}
adapter := adapters.NetAdapters[0]
t.Logf("Attempting to enable already enabled adapter: %s", adapter.Name)
// Calling Enable on an already enabled adapter should be a no-op and not return an error.
if err := adapter.Enable(); err != nil {
t.Errorf("Enable() on already enabled adapter returned an error: %v", err)
}
}

View File

@@ -0,0 +1,224 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package netw
import (
"fmt"
"github.com/google/deck"
"github.com/scjalliance/comshim"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
// Service represents a connection to the host Storage service (in WMI).
type Service struct {
wmiIntf *ole.IDispatch
wmiSvc *ole.IDispatch
}
// AdapterConnect connects to the WMI provider for managing network adapter objects.
// You must call Close() to release the provider when finished.
//
// Example: network.AdapterConnect()
func AdapterConnect() (Service, error) {
comshim.Add(1)
svc := Service{}
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
comshim.Done()
return svc, fmt.Errorf("CreateObject: %w", err)
}
defer unknown.Release()
svc.wmiIntf, err = unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
comshim.Done()
return svc, fmt.Errorf("QueryInterface: %w", err)
}
serviceRaw, err := oleutil.CallMethod(svc.wmiIntf, "ConnectServer", nil)
if err != nil {
svc.Close()
return svc, fmt.Errorf("ConnectServer: %w", err)
}
svc.wmiSvc = serviceRaw.ToIDispatch()
return svc, nil
}
// AdapterConfiguration represents a Win32_NetworkAdapterConfiguration object.
//
// Ref: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
type AdapterConfiguration struct {
Caption string
Description string
DHCPEnabled bool
DHCPServer string
DNSDomain string
DNSHostName string
DNSServerSearchOrder []string
IPAddress []string
IPEnabled bool
IPSubnet []string
MACAddress string
DefaultIPGateway []string
ServiceName string
SettingID string
InterfaceIndex uint32
handle *ole.IDispatch
}
// An AdapterConfigurationSet contains one or more NetworkAdapterConfigurations.
type AdapterConfigurationSet struct {
NetworkAdapterConfigurations []AdapterConfiguration
}
// GetNetworkAdapterConfigurations returns an AdapterConfigurationSet.
func (svc Service) GetNetworkAdapterConfigurations(filter string) (AdapterConfigurationSet, error) {
var configs AdapterConfigurationSet
query := "SELECT * FROM Win32_NetworkAdapterConfiguration"
if filter != "" {
query = fmt.Sprintf("%s %s", query, filter)
}
deck.InfoA(query).With(deck.V(1)).Go()
raw, err := oleutil.CallMethod(svc.wmiSvc, "ExecQuery", query)
if err != nil {
return configs, fmt.Errorf("ExecQuery(%s): %w", query, err)
}
result := raw.ToIDispatch()
defer result.Release()
countVar, err := oleutil.GetProperty(result, "Count")
if err != nil {
return configs, fmt.Errorf("oleutil.GetProperty(Count): %w", err)
}
count := int(countVar.Val)
for i := 0; i < count; i++ {
config := AdapterConfiguration{}
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
if err != nil {
return configs, fmt.Errorf("oleutil.CallMethod(ItemIndex, %d): %w", i, err)
}
config.handle = itemRaw.ToIDispatch()
if err := config.Query(); err != nil {
return configs, err
}
configs.NetworkAdapterConfigurations = append(configs.NetworkAdapterConfigurations, config)
}
return configs, nil
}
// Query reads and populates the network adapter state from WMI.
func (n *AdapterConfiguration) Query() error {
if n.handle == nil {
return fmt.Errorf("invalid handle")
}
// Non-string/slice properties
for _, prop := range [][]any{
{"DHCPEnabled", &n.DHCPEnabled},
{"IPEnabled", &n.IPEnabled},
{"InterfaceIndex", &n.InterfaceIndex},
} {
name, ok := prop[0].(string)
if !ok {
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
}
val, err := oleutil.GetProperty(n.handle, name)
if err != nil {
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
}
if val.VT != ole.VT_NULL {
if err := AssignVariant(val.Value(), prop[1]); err != nil {
deck.Warningf("AssignVariant(%s): %v", name, err)
}
}
}
// String properties
for _, prop := range [][]any{
{"Caption", &n.Caption},
{"Description", &n.Description},
{"DHCPServer", &n.DHCPServer},
{"DNSDomain", &n.DNSDomain},
{"DNSHostName", &n.DNSHostName},
{"MACAddress", &n.MACAddress},
{"ServiceName", &n.ServiceName},
{"SettingID", &n.SettingID},
} {
name, ok := prop[0].(string)
if !ok {
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
}
val, err := oleutil.GetProperty(n.handle, name)
if err != nil {
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
}
if val.VT != ole.VT_NULL {
*(prop[1].(*string)) = val.ToString()
}
}
// Slice properties
for _, sliceProp := range []struct {
Name string
Dst *[]string
}{
{"DNSServerSearchOrder", &n.DNSServerSearchOrder},
{"IPAddress", &n.IPAddress},
{"IPSubnet", &n.IPSubnet},
{"DefaultIPGateway", &n.DefaultIPGateway},
} {
prop, err := oleutil.GetProperty(n.handle, sliceProp.Name)
if err != nil {
deck.Warningf("oleutil.GetProperty(%s): %v", sliceProp.Name, err)
continue
}
if prop.VT != ole.VT_NULL {
for _, v := range prop.ToArray().ToValueArray() {
s, ok := v.(string)
if !ok {
deck.Warningf("error converting %s to string", sliceProp.Name)
} else {
*sliceProp.Dst = append(*sliceProp.Dst, s)
}
}
}
}
return nil
}
// StaticRoute sets the static route for the network adapter.
//
// IMPORTANT: This method ONLY supports Ipv4.
//
// Ref: https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/enablestatic-method-in-class-win32-networkadapterconfiguration
func (n *AdapterConfiguration) StaticRoute(ipaddress []string, subnetMask []string) error {
res, err := oleutil.CallMethod(n.handle, "EnableStatic", ipaddress, subnetMask)
if err != nil {
return fmt.Errorf("EnableStatic: %w", err)
}
if val, ok := res.Value().(int32); val != 0 || !ok {
return fmt.Errorf("error code returned during EnableStatic: %d", val)
}
return nil
}

View File

@@ -0,0 +1,75 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package netw
import (
"fmt"
"testing"
)
func TestGetNetAdapterConfigs(t *testing.T) {
n, err := AdapterConnect()
if err != nil {
t.Fatalf("failed to connect to network: %v", err)
}
defer n.Close()
// First get all configs to find a valid one for a filter test.
allConfigs, err := n.GetNetworkAdapterConfigurations("")
if err != nil {
t.Fatalf("Initial GetNetAdapterConfigs() failed: %v", err)
}
if len(allConfigs.NetworkAdapterConfigurations) == 0 {
t.Skip("No net adapter configs found, skipping test.")
}
configToFilter := allConfigs.NetworkAdapterConfigurations[0]
testCases := []struct {
name string
filter string
wantErr bool
}{
{
name: "no filter",
filter: "",
},
{
name: "filter by interface index",
filter: fmt.Sprintf("WHERE InterfaceIndex = %d", configToFilter.InterfaceIndex),
},
{
name: "bad filter",
filter: "WHERE BadFilter = 'true'",
wantErr: true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
configs, err := n.GetNetworkAdapterConfigurations(tt.filter)
if (err != nil) != tt.wantErr {
t.Errorf("GetNetAdapterConfigs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Expected error, nothing more to check.
}
if len(configs.NetworkAdapterConfigurations) == 0 {
t.Fatal("got 0 net adapter configs, want at least 1")
}
})
}
}

View File

@@ -0,0 +1,124 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package netw provides network adapter management functionality.
package netw
import (
"fmt"
"strconv"
"github.com/scjalliance/comshim"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
// Connect connects to the WMI provider for managing storage objects.
// You must call Close() to release the provider when finished.
//
// Example: storage.Connect()
func Connect() (Service, error) {
comshim.Add(1)
svc := Service{}
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
if err != nil {
comshim.Done()
return svc, fmt.Errorf("CreateObject: %w", err)
}
defer unknown.Release()
svc.wmiIntf, err = unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
comshim.Done()
return svc, fmt.Errorf("QueryInterface: %w", err)
}
serviceRaw, err := oleutil.CallMethod(svc.wmiIntf, "ConnectServer", nil, `\\.\ROOT\StandardCimv2`)
if err != nil {
svc.Close()
return svc, fmt.Errorf("ConnectServer: %w", err)
}
svc.wmiSvc = serviceRaw.ToIDispatch()
return svc, nil
}
// Close frees all resources associated with a volume.
func (svc *Service) Close() {
svc.wmiIntf.Release()
if svc.wmiSvc != nil {
svc.wmiSvc.Release()
}
comshim.Done()
}
// AssignVariant assigns a variant to a destination.
func AssignVariant(v any, dst any) error {
switch d := dst.(type) {
case *bool:
b, ok := v.(bool)
if !ok {
return fmt.Errorf("cannot assign %T to *bool", v)
}
*d = b
case *uint8:
switch val := v.(type) {
case uint8:
*d = val
case int16:
*d = uint8(val)
case int32:
*d = uint8(val)
default:
return fmt.Errorf("cannot assign %T to *uint8", v)
}
case *uint16:
i, ok := v.(int32)
if !ok {
i16, ok16 := v.(int16)
if !ok16 {
return fmt.Errorf("cannot assign %T to *uint16", v)
}
i = int32(i16)
}
*d = uint16(i)
case *uint32:
i, ok := v.(int32)
if !ok {
return fmt.Errorf("cannot assign %T to *uint32", v)
}
*d = uint32(i)
case *uint64:
s, ok := v.(string)
if ok {
parsed, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return fmt.Errorf("cannot parse uint64 from string '%s': %w", s, err)
}
*d = parsed
return nil
}
i, ok := v.(int64)
if !ok {
i32, ok32 := v.(int32)
if !ok32 {
return fmt.Errorf("cannot assign %T to *uint64", v)
}
i = int64(i32)
}
*d = uint64(i)
default:
return fmt.Errorf("unsupported destination type %T", dst)
}
return nil
}