mirror of
https://github.com/google/glazier.git
synced 2025-12-19 18:27:35 -05:00
* Add helper functions for;
+ MSFT_NetIPAddress
+ MSFT_NetAdapter
+ Disable, Enable, Rename support
+ Win32_NetworkAdapterConfiguration
* Add unit tests for all of the above.
PiperOrigin-RevId: 785509372
This commit is contained in:
committed by
Copybara-Service
parent
5c26fe5353
commit
0ed520ba4a
225
go/network/iphelper/ipaddress.go
Normal file
225
go/network/iphelper/ipaddress.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License")
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/deck"
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
)
|
||||
|
||||
// IPAddress represents a MSFT_NetIPAddress object.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/nettcpipprov/msft-netipaddress
|
||||
type IPAddress struct {
|
||||
AddressFamily uint16
|
||||
AddressState uint16
|
||||
InterfaceAlias string
|
||||
InterfaceIndex uint32
|
||||
IPAddress string
|
||||
PreferredLifetime string
|
||||
PrefixOrigin uint16
|
||||
SkipAsSource bool
|
||||
Store uint8
|
||||
SuffixOrigin uint16
|
||||
Type uint8
|
||||
ValidLifetime string
|
||||
|
||||
// handle is the internal ole handle
|
||||
handle *ole.IDispatch
|
||||
}
|
||||
|
||||
// CreateIPAddressOptions represents the options for creating an IP address.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/nettcpipprov/create-msft-netipaddress
|
||||
type CreateIPAddressOptions struct {
|
||||
InterfaceIndex uint32
|
||||
InterfaceAlias string
|
||||
IPAddress string
|
||||
AddressFamily uint16
|
||||
PrefixLength uint8
|
||||
Type uint8
|
||||
PrefixOrigin uint8
|
||||
SuffixOrigin uint8
|
||||
AddressState uint16
|
||||
ValidLifetime string // CIM_DATETIME
|
||||
PreferredLifetime string // CIM_DATETIME
|
||||
SkipAsSource bool
|
||||
DefaultGateway string
|
||||
PolicyStore string
|
||||
PassThru bool
|
||||
}
|
||||
|
||||
// IPAddressSet contains one or more IPAddresses.
|
||||
type IPAddressSet struct {
|
||||
IPAddresses []IPAddress
|
||||
}
|
||||
|
||||
// GetIPAddresses returns a IPAddresses struct.
|
||||
//
|
||||
// Get all IP addresses:
|
||||
//
|
||||
// svc.GetIPAddresses("")
|
||||
//
|
||||
// To get specific IP addresses, provide a valid WMI query filter string, for example:
|
||||
//
|
||||
// svc.GetIPAddresses("WHERE IPAddress='192.168.1.1'")
|
||||
func (svc Service) GetIPAddresses(filter string) (IPAddressSet, error) {
|
||||
var ipset IPAddressSet
|
||||
query := "SELECT * FROM MSFT_NetIPAddress"
|
||||
if filter != "" {
|
||||
query = fmt.Sprintf("%s %s", query, filter)
|
||||
}
|
||||
|
||||
deck.InfoA(query).With(deck.V(1)).Go()
|
||||
raw, err := oleutil.CallMethod(svc.wmiSvc, "ExecQuery", query)
|
||||
if err != nil {
|
||||
return ipset, fmt.Errorf("ExecQuery(%s): %w", query, err)
|
||||
}
|
||||
result := raw.ToIDispatch()
|
||||
defer result.Release()
|
||||
|
||||
countVar, err := oleutil.GetProperty(result, "Count")
|
||||
if err != nil {
|
||||
return ipset, fmt.Errorf("oleutil.GetProperty(Count): %w", err)
|
||||
}
|
||||
count := int(countVar.Val)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
ipresult := IPAddress{}
|
||||
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
|
||||
if err != nil {
|
||||
return ipset, fmt.Errorf("oleutil.CallMethod(ItemIndex, %d): %w", i, err)
|
||||
}
|
||||
ipresult.handle = itemRaw.ToIDispatch()
|
||||
|
||||
if err := ipresult.Query(); err != nil {
|
||||
return ipset, fmt.Errorf("ipresult.Query(): %w", err)
|
||||
}
|
||||
|
||||
ipset.IPAddresses = append(ipset.IPAddresses, ipresult)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// Close releases the handle to the IP address.
|
||||
func (ip *IPAddress) Close() {
|
||||
if ip.handle != nil {
|
||||
ip.handle.Release()
|
||||
}
|
||||
}
|
||||
|
||||
// Query reads and populates the IP address state from WMI.
|
||||
func (ip *IPAddress) Query() error {
|
||||
if ip.handle == nil {
|
||||
return fmt.Errorf("invalid handle")
|
||||
}
|
||||
|
||||
// All the non-string/slice properties
|
||||
for _, prop := range [][]any{
|
||||
{"AddressFamily", &ip.AddressFamily},
|
||||
{"InterfaceIndex", &ip.InterfaceIndex},
|
||||
{"PrefixOrigin", &ip.PrefixOrigin},
|
||||
{"SkipAsSource", &ip.SkipAsSource},
|
||||
{"SuffixOrigin", &ip.SuffixOrigin},
|
||||
{"Type", &ip.Type},
|
||||
{"Store", &ip.Store},
|
||||
{"AddressState", &ip.AddressState},
|
||||
} {
|
||||
name, ok := prop[0].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
|
||||
}
|
||||
val, err := oleutil.GetProperty(ip.handle, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
|
||||
}
|
||||
if val.VT != ole.VT_NULL {
|
||||
if err := AssignVariant(val.Value(), prop[1]); err != nil {
|
||||
deck.Warningf("AssignVariant(%s): %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String properties
|
||||
for _, prop := range [][]any{
|
||||
{"InterfaceAlias", &ip.InterfaceAlias},
|
||||
{"IPAddress", &ip.IPAddress},
|
||||
{"ValidLifetime", &ip.ValidLifetime},
|
||||
{"PreferredLifetime", &ip.PreferredLifetime},
|
||||
} {
|
||||
name, ok := prop[0].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
|
||||
}
|
||||
val, err := oleutil.GetProperty(ip.handle, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
|
||||
}
|
||||
if val.VT != ole.VT_NULL {
|
||||
*(prop[1].(*string)) = val.ToString()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPOutput represents the output of the Create method.
|
||||
type IPOutput struct{}
|
||||
|
||||
// Create creates the IP address on the current instance.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/nettcpipprov/create-msft-netipaddress
|
||||
func (ip *IPAddress) Create(opts CreateIPAddressOptions) (IPOutput, error) {
|
||||
ipset := IPOutput{}
|
||||
|
||||
return ipset, fmt.Errorf("not implemented")
|
||||
|
||||
// var createdObject ole.VARIANT
|
||||
// ole.VariantInit(&createdObject)
|
||||
|
||||
// Parameters must be passed in the order defined by the WMI method signature.
|
||||
// res, err := oleutil.CallMethod(ip.handle, "Create",
|
||||
// opts.InterfaceIndex,
|
||||
// opts.InterfaceAlias,
|
||||
// opts.IPAddress,
|
||||
// opts.AddressFamily,
|
||||
// opts.PrefixLength,
|
||||
// opts.Type,
|
||||
// opts.PrefixOrigin,
|
||||
// opts.SuffixOrigin,
|
||||
// opts.AddressState,
|
||||
// opts.ValidLifetime,
|
||||
// opts.PreferredLifetime,
|
||||
// opts.SkipAsSource,
|
||||
// opts.DefaultGateway,
|
||||
// opts.PolicyStore,
|
||||
// opts.PassThru,
|
||||
// &createdObject, // output
|
||||
// )
|
||||
// if err != nil {
|
||||
// return ipset, fmt.Errorf("Create: %w", err)
|
||||
// }
|
||||
// if val, ok := res.Value().(int32); val != 0 || !ok {
|
||||
// return ipset, fmt.Errorf("error code returned during create: %d", val)
|
||||
// }
|
||||
|
||||
// ip.handle = createdObject.ToIDispatch()
|
||||
|
||||
// return ipset, nil
|
||||
}
|
||||
84
go/network/iphelper/ipaddress_test.go
Normal file
84
go/network/iphelper/ipaddress_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License")
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetIPAddresses(t *testing.T) {
|
||||
n, err := Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to network: %v", err)
|
||||
}
|
||||
defer n.Close()
|
||||
|
||||
// First get all IPs to find a valid address for a filter test.
|
||||
allIPs, err := n.GetIPAddresses("")
|
||||
if err != nil {
|
||||
t.Fatalf("Initial GetIPAddresses() failed: %v", err)
|
||||
}
|
||||
if len(allIPs.IPAddresses) == 0 {
|
||||
t.Skip("No IP addresses found, skipping test.")
|
||||
}
|
||||
ipToFilter := allIPs.IPAddresses[0]
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filter string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no filter",
|
||||
filter: "",
|
||||
},
|
||||
{
|
||||
name: "filter by ip address",
|
||||
filter: fmt.Sprintf("WHERE IPAddress = '%s'", ipToFilter.IPAddress),
|
||||
},
|
||||
{
|
||||
name: "bad filter",
|
||||
filter: "WHERE BadFilter = 'true'",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ips, err := n.GetIPAddresses(tt.filter)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GetIPAddresses() error = %v, wantErr = %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return // Expected error, nothing more to check.
|
||||
}
|
||||
|
||||
if len(ips.IPAddresses) == 0 {
|
||||
t.Fatal("got 0 ip addresses, want at least 1")
|
||||
}
|
||||
|
||||
for _, ip := range ips.IPAddresses {
|
||||
if ip.IPAddress == "" {
|
||||
t.Error("ip has an empty IPAddress")
|
||||
}
|
||||
if ip.InterfaceIndex == 0 {
|
||||
t.Errorf("ip %q has a zero InterfaceIndex", ip.IPAddress)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
485
go/network/iphelper/netadapter.go
Normal file
485
go/network/iphelper/netadapter.go
Normal file
@@ -0,0 +1,485 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/deck"
|
||||
"golang.org/x/sys/windows"
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
)
|
||||
|
||||
// NetAdapter represents a MSFT_NetAdapter object. It's important to note that some fields are read-only.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
type NetAdapter struct {
|
||||
// The name of the network adapter. This property is inherited from CIM_ManagedSystemElement.
|
||||
Name string
|
||||
// Current status of the object. This property is inherited from CIM_ManagedSystemElement.
|
||||
Status string
|
||||
// The availability and status of the device. This property is inherited from CIM_LogicalDevice.
|
||||
Availability uint16
|
||||
// If TRUE, the device is using a user-defined configuration. This property is inherited from CIM_LogicalDevice.
|
||||
ConfigManagerUserConfig bool
|
||||
// The name of the class or subclass used in the creation of an instance. This property is inherited from CIM_LogicalDevice.
|
||||
CreationClassName string
|
||||
// The address or other identifying information for the network adapter. This property is inherited from CIM_LogicalDevice.
|
||||
DeviceID string
|
||||
// If TRUE, the error reported in the LastErrorCode property is now cleared. This property is inherited from CIM_LogicalDevice.
|
||||
ErrorCleared bool
|
||||
// A string that provides more information about the error recorded in the LastErrorCode property and information about any corrective actions that can be taken. This property is inherited from CIM_LogicalDevice.
|
||||
ErrorDescription string
|
||||
// The last error code reported by the logical device. This property is inherited from CIM_LogicalDevice.
|
||||
LastErrorCode uint32
|
||||
// The Plug and Play device identifier of the logical device. This property is inherited from CIM_LogicalDevice.
|
||||
PNPDeviceID string
|
||||
// An array of the specific power-related capabilities of a logical device. This property is inherited from CIM_LogicalDevice.
|
||||
PowerManagementCapabilities []uint16
|
||||
// If TRUE, the device can be power managed. This property is inherited from CIM_LogicalDevice.
|
||||
PowerManagementSupported bool
|
||||
// The status of the logical device. This property is inherited from CIM_LogicalDevice.
|
||||
StatusInfo uint16
|
||||
// The value of the CreationClassName property for the scoping system. This property is inherited from CIM_LogicalDevice.
|
||||
SystemCreationClassName string
|
||||
// The name of the scoping system. This property is inherited from CIM_LogicalDevice.
|
||||
SystemName string
|
||||
// The current bandwidth of the port in bits per second.
|
||||
Speed uint64
|
||||
// The maximum bandwidth of the port in bits per second.
|
||||
MaxSpeed uint64
|
||||
// The requested bandwidth of the port in bits per second.
|
||||
RequestedSpeed uint64
|
||||
// In cases where a port can be used for more than one function, this property indicates its primary usage.
|
||||
UsageRestriction uint16
|
||||
// The specific type of the port.
|
||||
PortType uint16
|
||||
// A string that describes the port type when the PortType property is set to 1 (Other).
|
||||
OtherPortType string
|
||||
// The network port type when the PortType property is set to 1 (Other).
|
||||
OtherNetworkPortType string
|
||||
// The port number.
|
||||
PortNumber uint16
|
||||
// The link technology for the port.
|
||||
LinkTechnology uint16
|
||||
// A string that describes the link technology when the LinkTechnology property is set to 1 (Other).
|
||||
OtherLinkTechnology string
|
||||
// The network address that is hardcoded into a port.
|
||||
PermanentAddress string
|
||||
// An array of network addresses for the port.
|
||||
NetworkAddresses []string
|
||||
// If TRUE, the port is operating in full duplex mode.
|
||||
FullDuplex bool
|
||||
// If TRUE, the port can automatically determine the speed or other communications characteristics of the attached network media.
|
||||
AutoSense bool
|
||||
// The maximum transmission unit (MTU) that can be supported.
|
||||
SupportedMaximumTransmissionUnit uint64
|
||||
// The active or negotiated maximum transmission unit (MTU) on the port.
|
||||
ActiveMaximumTransmissionUnit uint64
|
||||
// The description of the network interface.
|
||||
InterfaceDescription string
|
||||
// The name of the network interface.
|
||||
InterfaceName string
|
||||
// The network layer unique identifier (LUID) for the network interface.
|
||||
NetLuid uint64
|
||||
// The GUID for the network interface.
|
||||
InterfaceGUID windows.GUID
|
||||
// The index for the network interface.
|
||||
InterfaceIndex uint32
|
||||
// The name of the device object for the network adapter.
|
||||
DeviceName string
|
||||
// The index of the LUID for the network adapter.
|
||||
NetLuidIndex uint32
|
||||
// If TRUE, this is a virtual network adapter.
|
||||
Virtual bool
|
||||
// If TRUE, this network adapter is not displayed in the user interface.
|
||||
Hidden bool
|
||||
// If TRUE, this network adapter cannot be removed by a user.
|
||||
NotUserRemovable bool
|
||||
// If TRUE, this is an intermediate driver filter.
|
||||
IMFilter bool
|
||||
// The interface type as defined by the Internet Assigned Names Authority (IANA).
|
||||
InterfaceType uint32
|
||||
// If TRUE, this is a hardware interface.
|
||||
HardwareInterface bool
|
||||
// If TRUE, this is a WDM interface.
|
||||
WdmInterface bool
|
||||
// If TRUE, this is an endpoint interface.
|
||||
EndPointInterface bool
|
||||
// If TRUE, this is an iSCSI interface.
|
||||
ISCSIInterface bool
|
||||
// The current state of the network adapter.
|
||||
State uint32
|
||||
// The media type that the network adapter supports.
|
||||
NdisMedium uint32
|
||||
// The physical media type of the network adapter.
|
||||
NdisPhysicalMedium uint32
|
||||
// The operational status of the network interface.
|
||||
InterfaceOperationalStatus uint32
|
||||
// If TRUE, the operational status is down because the default port is not authenticated.
|
||||
OperationalStatusDownDefaultPortNotAuthenticated bool
|
||||
// If TRUE, the operational status is down because the media is disconnected.
|
||||
OperationalStatusDownMediaDisconnected bool
|
||||
// If TRUE, the operational status is down because the interface is paused.
|
||||
OperationalStatusDownInterfacePaused bool
|
||||
// If TRUE, the operational status is down because the interface is in a low power state.
|
||||
OperationalStatusDownLowPowerState bool
|
||||
// The administrative status of the network interface.
|
||||
InterfaceAdminStatus uint32
|
||||
// The media connect state of the network adapter.
|
||||
MediaConnectState uint32
|
||||
// The maximum transmission unit (MTU) size for the network adapter.
|
||||
MtuSize uint32
|
||||
// The VLAN identifier for the network adapter.
|
||||
VlanID uint16
|
||||
// The transmit link speed for the network adapter.
|
||||
TransmitLinkSpeed uint64
|
||||
// The receive link speed for the network adapter.
|
||||
ReceiveLinkSpeed uint64
|
||||
// If TRUE, the network adapter is in promiscuous mode.
|
||||
PromiscuousMode bool
|
||||
// If TRUE, the device is enabled to wake the system.
|
||||
DeviceWakeUpEnable bool
|
||||
// If TRUE, a connector is present on the network adapter.
|
||||
ConnectorPresent bool
|
||||
// The duplex state of the media.
|
||||
MediaDuplexState uint32
|
||||
// The date of the driver for the network adapter.
|
||||
DriverDate string
|
||||
// The date of the driver for the network adapter, in 100-nanosecond intervals.
|
||||
DriverDateData uint64
|
||||
// The version of the driver for the network adapter.
|
||||
DriverVersionString string
|
||||
// The name of the driver for the network adapter.
|
||||
DriverName string
|
||||
// The description of the driver for the network adapter.
|
||||
DriverDescription string
|
||||
// The major version of the driver for the network adapter.
|
||||
MajorDriverVersion uint16
|
||||
// The minor version of the driver for the network adapter.
|
||||
MinorDriverVersion uint16
|
||||
// The major NDIS version of the driver for the network adapter.
|
||||
DriverMajorNdisVersion uint8
|
||||
// The minor NDIS version of the driver for the network adapter.
|
||||
DriverMinorNdisVersion uint8
|
||||
// The provider of the driver for the network adapter.
|
||||
DriverProvider string
|
||||
// The component identifier for the network adapter.
|
||||
ComponentID string
|
||||
// The indices of the lower layer interfaces.
|
||||
LowerLayerInterfaceIndices []uint32
|
||||
// The indices of the higher layer interfaces.
|
||||
HigherLayerInterfaceIndices []uint32
|
||||
// If TRUE, the network adapter is administratively locked.
|
||||
AdminLocked bool
|
||||
|
||||
handle *ole.IDispatch
|
||||
}
|
||||
|
||||
// A NetAdapterSet contains one or more NetAdapters.
|
||||
type NetAdapterSet struct {
|
||||
NetAdapters []NetAdapter
|
||||
}
|
||||
|
||||
// Close releases the handle to the network adapter.
|
||||
func (n *NetAdapter) Close() {
|
||||
if n.handle != nil {
|
||||
n.handle.Release()
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetAdapters queries for local network adapters.
|
||||
//
|
||||
// Close() must be called on the resulting NetAdapter to ensure all network adapters are released.
|
||||
//
|
||||
// Get all network adapters:
|
||||
//
|
||||
// svc.GetNetAdapters("")
|
||||
//
|
||||
// To get specific network adapters, provide a valid WMI query filter string, for example:
|
||||
//
|
||||
// svc.GetNetAdapters("WHERE Name='Wi-Fi'")
|
||||
func (svc Service) GetNetAdapters(filter string) (NetAdapterSet, error) {
|
||||
var netAdapters NetAdapterSet
|
||||
query := "SELECT * FROM MSFT_NetAdapter"
|
||||
if filter != "" {
|
||||
query = fmt.Sprintf("%s %s", query, filter)
|
||||
}
|
||||
|
||||
deck.InfoA(query).With(deck.V(1)).Go()
|
||||
raw, err := oleutil.CallMethod(svc.wmiSvc, "ExecQuery", query)
|
||||
if err != nil {
|
||||
return netAdapters, fmt.Errorf("ExecQuery(%s): %w", query, err)
|
||||
}
|
||||
result := raw.ToIDispatch()
|
||||
defer result.Release()
|
||||
|
||||
countVar, err := oleutil.GetProperty(result, "Count")
|
||||
if err != nil {
|
||||
return netAdapters, fmt.Errorf("oleutil.GetProperty(Count): %w", err)
|
||||
}
|
||||
count := int(countVar.Val)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
netAdapter := NetAdapter{}
|
||||
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
|
||||
if err != nil {
|
||||
return netAdapters, fmt.Errorf("oleutil.CallMethod(ItemIndex, %d): %w", i, err)
|
||||
}
|
||||
netAdapter.handle = itemRaw.ToIDispatch()
|
||||
|
||||
if err := netAdapter.Query(); err != nil {
|
||||
return netAdapters, fmt.Errorf("netAdapter.Query(): %w", err)
|
||||
}
|
||||
|
||||
netAdapters.NetAdapters = append(netAdapters.NetAdapters, netAdapter)
|
||||
}
|
||||
|
||||
return netAdapters, nil
|
||||
}
|
||||
|
||||
// Disable disables the network adapter.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/netadaptercimprov/disable-msft-netadapter
|
||||
func (n *NetAdapter) Disable() error {
|
||||
if n.handle == nil {
|
||||
return fmt.Errorf("invalid handle")
|
||||
}
|
||||
res, err := oleutil.CallMethod(n.handle, "Disable")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Disable: %w", err)
|
||||
}
|
||||
if val, ok := res.Value().(int32); val != 0 || !ok {
|
||||
return fmt.Errorf("error code returned during Disable: %d", val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enable enables the network adapter.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/netadaptercimprov/enable-msft-netadapter
|
||||
func (n *NetAdapter) Enable() error {
|
||||
if n.handle == nil {
|
||||
return fmt.Errorf("invalid handle")
|
||||
}
|
||||
res, err := oleutil.CallMethod(n.handle, "Enable")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Enable: %w", err)
|
||||
}
|
||||
if val, ok := res.Value().(int32); val != 0 || !ok {
|
||||
return fmt.Errorf("error code returned during Enable: %d", val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rename renames the network adapter.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/fwp/wmi/netadaptercimprov/msft-netadapter-rename
|
||||
func (n *NetAdapter) Rename(name string) error {
|
||||
if n.handle == nil {
|
||||
return fmt.Errorf("invalid handle")
|
||||
}
|
||||
res, err := oleutil.CallMethod(n.handle, "Rename", name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Rename: %w", err)
|
||||
}
|
||||
if val, ok := res.Value().(int32); val != 0 || !ok {
|
||||
return fmt.Errorf("error code returned during Rename: %d", val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query reads and populates the network adapter state from WMI.
|
||||
func (n *NetAdapter) Query() error {
|
||||
if n.handle == nil {
|
||||
return fmt.Errorf("invalid handle")
|
||||
}
|
||||
|
||||
// All the non-string/slice properties
|
||||
for _, prop := range [][]any{
|
||||
{"Availability", &n.Availability},
|
||||
{"ErrorCleared", &n.ErrorCleared},
|
||||
{"LastErrorCode", &n.LastErrorCode},
|
||||
{"PowerManagementSupported", &n.PowerManagementSupported},
|
||||
{"StatusInfo", &n.StatusInfo},
|
||||
{"Speed", &n.Speed},
|
||||
{"MaxSpeed", &n.MaxSpeed},
|
||||
{"RequestedSpeed", &n.RequestedSpeed},
|
||||
{"UsageRestriction", &n.UsageRestriction},
|
||||
{"PortType", &n.PortType},
|
||||
{"PortNumber", &n.PortNumber},
|
||||
{"LinkTechnology", &n.LinkTechnology},
|
||||
{"FullDuplex", &n.FullDuplex},
|
||||
{"AutoSense", &n.AutoSense},
|
||||
{"SupportedMaximumTransmissionUnit", &n.SupportedMaximumTransmissionUnit},
|
||||
{"ActiveMaximumTransmissionUnit", &n.ActiveMaximumTransmissionUnit},
|
||||
{"NetLuid", &n.NetLuid},
|
||||
{"InterfaceIndex", &n.InterfaceIndex},
|
||||
{"NetLuidIndex", &n.NetLuidIndex},
|
||||
{"Virtual", &n.Virtual},
|
||||
{"Hidden", &n.Hidden},
|
||||
{"NotUserRemovable", &n.NotUserRemovable},
|
||||
{"IMFilter", &n.IMFilter},
|
||||
{"InterfaceType", &n.InterfaceType},
|
||||
{"HardwareInterface", &n.HardwareInterface},
|
||||
{"WdmInterface", &n.WdmInterface},
|
||||
{"EndPointInterface", &n.EndPointInterface},
|
||||
{"ISCSIInterface", &n.ISCSIInterface},
|
||||
{"State", &n.State},
|
||||
{"NdisMedium", &n.NdisMedium},
|
||||
{"NdisPhysicalMedium", &n.NdisPhysicalMedium},
|
||||
{"InterfaceOperationalStatus", &n.InterfaceOperationalStatus},
|
||||
{"OperationalStatusDownDefaultPortNotAuthenticated", &n.OperationalStatusDownDefaultPortNotAuthenticated},
|
||||
{"OperationalStatusDownMediaDisconnected", &n.OperationalStatusDownMediaDisconnected},
|
||||
{"OperationalStatusDownInterfacePaused", &n.OperationalStatusDownInterfacePaused},
|
||||
{"OperationalStatusDownLowPowerState", &n.OperationalStatusDownLowPowerState},
|
||||
{"InterfaceAdminStatus", &n.InterfaceAdminStatus},
|
||||
{"MediaConnectState", &n.MediaConnectState},
|
||||
{"MtuSize", &n.MtuSize},
|
||||
{"VlanID", &n.VlanID},
|
||||
{"TransmitLinkSpeed", &n.TransmitLinkSpeed},
|
||||
{"ReceiveLinkSpeed", &n.ReceiveLinkSpeed},
|
||||
{"PromiscuousMode", &n.PromiscuousMode},
|
||||
{"DeviceWakeUpEnable", &n.DeviceWakeUpEnable},
|
||||
{"ConnectorPresent", &n.ConnectorPresent},
|
||||
{"MediaDuplexState", &n.MediaDuplexState},
|
||||
{"DriverDateData", &n.DriverDateData},
|
||||
{"MajorDriverVersion", &n.MajorDriverVersion},
|
||||
{"MinorDriverVersion", &n.MinorDriverVersion},
|
||||
{"DriverMajorNdisVersion", &n.DriverMajorNdisVersion},
|
||||
{"DriverMinorNdisVersion", &n.DriverMinorNdisVersion},
|
||||
{"AdminLocked", &n.AdminLocked},
|
||||
} {
|
||||
val, err := oleutil.GetProperty(n.handle, prop[0].(string))
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(%s): %w", prop[0].(string), err)
|
||||
}
|
||||
if val.VT != ole.VT_NULL {
|
||||
if err := AssignVariant(val.Value(), prop[1]); err != nil {
|
||||
deck.Warningf("AssignVariant(%s): %v", prop[0].(string), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String properties
|
||||
for _, prop := range [][]any{
|
||||
{"Name", &n.Name},
|
||||
{"Status", &n.Status},
|
||||
{"CreationClassName", &n.CreationClassName},
|
||||
{"DeviceID", &n.DeviceID},
|
||||
{"ErrorDescription", &n.ErrorDescription},
|
||||
{"PNPDeviceID", &n.PNPDeviceID},
|
||||
{"SystemCreationClassName", &n.SystemCreationClassName},
|
||||
{"SystemName", &n.SystemName},
|
||||
{"OtherPortType", &n.OtherPortType},
|
||||
{"OtherNetworkPortType", &n.OtherNetworkPortType},
|
||||
{"OtherLinkTechnology", &n.OtherLinkTechnology},
|
||||
{"PermanentAddress", &n.PermanentAddress},
|
||||
{"InterfaceDescription", &n.InterfaceDescription},
|
||||
{"InterfaceName", &n.InterfaceName},
|
||||
{"DeviceName", &n.DeviceName},
|
||||
{"DriverDate", &n.DriverDate},
|
||||
{"DriverVersionString", &n.DriverVersionString},
|
||||
{"DriverName", &n.DriverName},
|
||||
{"DriverDescription", &n.DriverDescription},
|
||||
{"DriverProvider", &n.DriverProvider},
|
||||
{"ComponentID", &n.ComponentID},
|
||||
} {
|
||||
val, err := oleutil.GetProperty(n.handle, prop[0].(string))
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(%s): %w", prop[0].(string), err)
|
||||
}
|
||||
if val.VT != ole.VT_NULL {
|
||||
*(prop[1].(*string)) = val.ToString()
|
||||
}
|
||||
}
|
||||
|
||||
// GUID
|
||||
prop, err := oleutil.GetProperty(n.handle, "InterfaceGUID")
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(InterfaceGUID): %w", err)
|
||||
}
|
||||
if prop.VT != ole.VT_NULL {
|
||||
guid, err := windows.GUIDFromString(prop.ToString())
|
||||
if err != nil {
|
||||
return fmt.Errorf("GUIDFromString(%s): %w", prop.ToString(), err)
|
||||
}
|
||||
n.InterfaceGUID = guid
|
||||
}
|
||||
|
||||
// Slice properties
|
||||
// NetworkAddresses
|
||||
prop, err = oleutil.GetProperty(n.handle, "NetworkAddresses")
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(NetworkAddresses): %w", err)
|
||||
}
|
||||
if prop.VT != ole.VT_NULL {
|
||||
for _, v := range prop.ToArray().ToValueArray() {
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting NetworkAddress to string")
|
||||
}
|
||||
n.NetworkAddresses = append(n.NetworkAddresses, s)
|
||||
}
|
||||
}
|
||||
|
||||
// PowerManagementCapabilities
|
||||
prop, err = oleutil.GetProperty(n.handle, "PowerManagementCapabilities")
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(PowerManagementCapabilities): %w", err)
|
||||
}
|
||||
if prop.VT != ole.VT_NULL {
|
||||
for _, v := range prop.ToArray().ToValueArray() {
|
||||
val, ok := v.(int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting PowerManagementCapabilities to uint16, got %T", v)
|
||||
}
|
||||
n.PowerManagementCapabilities = append(n.PowerManagementCapabilities, uint16(val))
|
||||
}
|
||||
}
|
||||
|
||||
// LowerLayerInterfaceIndices
|
||||
prop, err = oleutil.GetProperty(n.handle, "LowerLayerInterfaceIndices")
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(LowerLayerInterfaceIndices): %w", err)
|
||||
}
|
||||
if prop.VT != ole.VT_NULL {
|
||||
for _, v := range prop.ToArray().ToValueArray() {
|
||||
val, ok := v.(int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting LowerLayerInterfaceIndices to uint32, got %T", v)
|
||||
}
|
||||
n.LowerLayerInterfaceIndices = append(n.LowerLayerInterfaceIndices, uint32(val))
|
||||
}
|
||||
}
|
||||
|
||||
// HigherLayerInterfaceIndices
|
||||
prop, err = oleutil.GetProperty(n.handle, "HigherLayerInterfaceIndices")
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(HigherLayerInterfaceIndices): %w", err)
|
||||
}
|
||||
if prop.VT != ole.VT_NULL {
|
||||
for _, v := range prop.ToArray().ToValueArray() {
|
||||
val, ok := v.(int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting HigherLayerInterfaceIndices to uint32, got %T", v)
|
||||
}
|
||||
n.HigherLayerInterfaceIndices = append(n.HigherLayerInterfaceIndices, uint32(val))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
203
go/network/iphelper/netadapter_test.go
Normal file
203
go/network/iphelper/netadapter_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License")
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestGetNetAdapters(t *testing.T) {
|
||||
n, err := Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to network: %v", err)
|
||||
}
|
||||
defer n.Close()
|
||||
|
||||
// First get all adapters to find a valid name for a filter test.
|
||||
allAdapters, err := n.GetNetAdapters("")
|
||||
if err != nil {
|
||||
t.Fatalf("Initial GetNetAdapters() failed: %v", err)
|
||||
}
|
||||
if len(allAdapters.NetAdapters) == 0 {
|
||||
t.Skip("No network adapters found, skipping test.")
|
||||
}
|
||||
adapterToFilter := allAdapters.NetAdapters[0]
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filter string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no filter",
|
||||
filter: "",
|
||||
},
|
||||
{
|
||||
name: "filter by name",
|
||||
filter: fmt.Sprintf("WHERE Name = '%s'", adapterToFilter.Name),
|
||||
},
|
||||
{
|
||||
name: "bad filter",
|
||||
filter: fmt.Sprintf("WHERE OS = 'MacOS'"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
adapters, err := n.GetNetAdapters(tt.filter)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GetNetAdapters() error = %v, wantErr = %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return // Expected error, nothing more to check.
|
||||
}
|
||||
|
||||
if len(adapters.NetAdapters) == 0 {
|
||||
t.Fatal("got 0 adapters, want at least 1")
|
||||
}
|
||||
|
||||
for _, adapter := range adapters.NetAdapters {
|
||||
if adapter.Name == "" {
|
||||
t.Error("adapter has an empty Name")
|
||||
}
|
||||
if adapter.InterfaceDescription == "" {
|
||||
t.Errorf("adapter %q has an empty InterfaceDescription", adapter.Name)
|
||||
}
|
||||
if adapter.InterfaceGUID == (windows.GUID{}) {
|
||||
t.Errorf("adapter %q has a zero InterfaceGUID", adapter.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameAdapter(t *testing.T) {
|
||||
n, err := Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to network: %v", err)
|
||||
}
|
||||
defer n.Close()
|
||||
adapters, err := n.GetNetAdapters("")
|
||||
if err != nil {
|
||||
t.Fatalf("GetNetAdapters() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(adapters.NetAdapters) == 0 {
|
||||
t.Fatal("GetNetAdapters() returned no adapters, cannot test rename.")
|
||||
}
|
||||
|
||||
adapterToTest := &adapters.NetAdapters[0]
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
testName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple rename",
|
||||
testName: "cider-test-adapter",
|
||||
},
|
||||
{
|
||||
name: "empty name",
|
||||
testName: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Get the current state of the adapter to have the correct original name for cleanup.
|
||||
currentAdapterState, err := findAdapterByGUID(t, n, adapterToTest.InterfaceGUID)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not find adapter by GUID before renaming: %v", err)
|
||||
}
|
||||
originalName := currentAdapterState.Name
|
||||
|
||||
// Defer cleanup to restore the original name at the end of the sub-test.
|
||||
defer func() {
|
||||
adapterForCleanup, findErr := findAdapterByGUID(t, n, adapterToTest.InterfaceGUID)
|
||||
if findErr != nil {
|
||||
t.Logf("Could not find adapter for cleanup: %v", findErr)
|
||||
return
|
||||
}
|
||||
if err := adapterForCleanup.Rename(originalName); err != nil {
|
||||
t.Errorf("Failed to rename adapter back to %q: %v", originalName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = currentAdapterState.Rename(tt.testName)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("Rename() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return // Expected error, nothing more to check.
|
||||
}
|
||||
|
||||
// Verify the rename worked.
|
||||
renamedAdapter, err := findAdapterByGUID(t, n, adapterToTest.InterfaceGUID)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not find adapter by GUID after renaming: %v", err)
|
||||
}
|
||||
if renamedAdapter.Name != tt.testName {
|
||||
t.Errorf("adapter name is %q, want %q", renamedAdapter.Name, tt.testName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// findAdapterByGUID is a helper to find a specific network adapter.
|
||||
func findAdapterByGUID(t *testing.T, n Service, guid windows.GUID) (*NetAdapter, error) {
|
||||
t.Helper()
|
||||
adapters, err := n.GetNetAdapters("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetNetAdapters() failed: %w", err)
|
||||
}
|
||||
for i := range adapters.NetAdapters {
|
||||
if adapters.NetAdapters[i].InterfaceGUID == guid {
|
||||
return &adapters.NetAdapters[i], nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("adapter with GUID %s not found", guid)
|
||||
}
|
||||
|
||||
func TestEnableAlreadyEnabledAdapter(t *testing.T) {
|
||||
n, err := Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to network: %v", err)
|
||||
}
|
||||
defer n.Close()
|
||||
|
||||
// Filter for enabled adapters using InterfaceAdminStatus = 1 (Up).
|
||||
adapters, err := n.GetNetAdapters("WHERE InterfaceAdminStatus = 1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetNetAdapters() failed: %v", err)
|
||||
}
|
||||
if len(adapters.NetAdapters) == 0 {
|
||||
t.Skip("No enabled network adapters found, skipping test.")
|
||||
}
|
||||
|
||||
adapter := adapters.NetAdapters[0]
|
||||
t.Logf("Attempting to enable already enabled adapter: %s", adapter.Name)
|
||||
|
||||
// Calling Enable on an already enabled adapter should be a no-op and not return an error.
|
||||
if err := adapter.Enable(); err != nil {
|
||||
t.Errorf("Enable() on already enabled adapter returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
224
go/network/iphelper/netadapterconfig.go
Normal file
224
go/network/iphelper/netadapterconfig.go
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/deck"
|
||||
"github.com/scjalliance/comshim"
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
)
|
||||
|
||||
// Service represents a connection to the host Storage service (in WMI).
|
||||
type Service struct {
|
||||
wmiIntf *ole.IDispatch
|
||||
wmiSvc *ole.IDispatch
|
||||
}
|
||||
|
||||
// AdapterConnect connects to the WMI provider for managing network adapter objects.
|
||||
// You must call Close() to release the provider when finished.
|
||||
//
|
||||
// Example: network.AdapterConnect()
|
||||
func AdapterConnect() (Service, error) {
|
||||
comshim.Add(1)
|
||||
svc := Service{}
|
||||
|
||||
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
|
||||
if err != nil {
|
||||
comshim.Done()
|
||||
return svc, fmt.Errorf("CreateObject: %w", err)
|
||||
}
|
||||
defer unknown.Release()
|
||||
svc.wmiIntf, err = unknown.QueryInterface(ole.IID_IDispatch)
|
||||
if err != nil {
|
||||
comshim.Done()
|
||||
return svc, fmt.Errorf("QueryInterface: %w", err)
|
||||
}
|
||||
serviceRaw, err := oleutil.CallMethod(svc.wmiIntf, "ConnectServer", nil)
|
||||
if err != nil {
|
||||
svc.Close()
|
||||
return svc, fmt.Errorf("ConnectServer: %w", err)
|
||||
}
|
||||
svc.wmiSvc = serviceRaw.ToIDispatch()
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// AdapterConfiguration represents a Win32_NetworkAdapterConfiguration object.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
type AdapterConfiguration struct {
|
||||
Caption string
|
||||
Description string
|
||||
DHCPEnabled bool
|
||||
DHCPServer string
|
||||
DNSDomain string
|
||||
DNSHostName string
|
||||
DNSServerSearchOrder []string
|
||||
IPAddress []string
|
||||
IPEnabled bool
|
||||
IPSubnet []string
|
||||
MACAddress string
|
||||
DefaultIPGateway []string
|
||||
ServiceName string
|
||||
SettingID string
|
||||
InterfaceIndex uint32
|
||||
|
||||
handle *ole.IDispatch
|
||||
}
|
||||
|
||||
// An AdapterConfigurationSet contains one or more NetworkAdapterConfigurations.
|
||||
type AdapterConfigurationSet struct {
|
||||
NetworkAdapterConfigurations []AdapterConfiguration
|
||||
}
|
||||
|
||||
// GetNetworkAdapterConfigurations returns an AdapterConfigurationSet.
|
||||
func (svc Service) GetNetworkAdapterConfigurations(filter string) (AdapterConfigurationSet, error) {
|
||||
var configs AdapterConfigurationSet
|
||||
query := "SELECT * FROM Win32_NetworkAdapterConfiguration"
|
||||
if filter != "" {
|
||||
query = fmt.Sprintf("%s %s", query, filter)
|
||||
}
|
||||
|
||||
deck.InfoA(query).With(deck.V(1)).Go()
|
||||
raw, err := oleutil.CallMethod(svc.wmiSvc, "ExecQuery", query)
|
||||
if err != nil {
|
||||
return configs, fmt.Errorf("ExecQuery(%s): %w", query, err)
|
||||
}
|
||||
result := raw.ToIDispatch()
|
||||
defer result.Release()
|
||||
|
||||
countVar, err := oleutil.GetProperty(result, "Count")
|
||||
if err != nil {
|
||||
return configs, fmt.Errorf("oleutil.GetProperty(Count): %w", err)
|
||||
}
|
||||
count := int(countVar.Val)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
config := AdapterConfiguration{}
|
||||
itemRaw, err := oleutil.CallMethod(result, "ItemIndex", i)
|
||||
if err != nil {
|
||||
return configs, fmt.Errorf("oleutil.CallMethod(ItemIndex, %d): %w", i, err)
|
||||
}
|
||||
config.handle = itemRaw.ToIDispatch()
|
||||
|
||||
if err := config.Query(); err != nil {
|
||||
return configs, err
|
||||
}
|
||||
|
||||
configs.NetworkAdapterConfigurations = append(configs.NetworkAdapterConfigurations, config)
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// Query reads and populates the network adapter state from WMI.
|
||||
func (n *AdapterConfiguration) Query() error {
|
||||
if n.handle == nil {
|
||||
return fmt.Errorf("invalid handle")
|
||||
}
|
||||
|
||||
// Non-string/slice properties
|
||||
for _, prop := range [][]any{
|
||||
{"DHCPEnabled", &n.DHCPEnabled},
|
||||
{"IPEnabled", &n.IPEnabled},
|
||||
{"InterfaceIndex", &n.InterfaceIndex},
|
||||
} {
|
||||
name, ok := prop[0].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
|
||||
}
|
||||
val, err := oleutil.GetProperty(n.handle, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
|
||||
}
|
||||
if val.VT != ole.VT_NULL {
|
||||
if err := AssignVariant(val.Value(), prop[1]); err != nil {
|
||||
deck.Warningf("AssignVariant(%s): %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String properties
|
||||
for _, prop := range [][]any{
|
||||
{"Caption", &n.Caption},
|
||||
{"Description", &n.Description},
|
||||
{"DHCPServer", &n.DHCPServer},
|
||||
{"DNSDomain", &n.DNSDomain},
|
||||
{"DNSHostName", &n.DNSHostName},
|
||||
{"MACAddress", &n.MACAddress},
|
||||
{"ServiceName", &n.ServiceName},
|
||||
{"SettingID", &n.SettingID},
|
||||
} {
|
||||
name, ok := prop[0].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to convert property name to string: %v", prop[0])
|
||||
}
|
||||
val, err := oleutil.GetProperty(n.handle, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oleutil.GetProperty(%s): %w", name, err)
|
||||
}
|
||||
if val.VT != ole.VT_NULL {
|
||||
*(prop[1].(*string)) = val.ToString()
|
||||
}
|
||||
}
|
||||
|
||||
// Slice properties
|
||||
for _, sliceProp := range []struct {
|
||||
Name string
|
||||
Dst *[]string
|
||||
}{
|
||||
{"DNSServerSearchOrder", &n.DNSServerSearchOrder},
|
||||
{"IPAddress", &n.IPAddress},
|
||||
{"IPSubnet", &n.IPSubnet},
|
||||
{"DefaultIPGateway", &n.DefaultIPGateway},
|
||||
} {
|
||||
prop, err := oleutil.GetProperty(n.handle, sliceProp.Name)
|
||||
if err != nil {
|
||||
deck.Warningf("oleutil.GetProperty(%s): %v", sliceProp.Name, err)
|
||||
continue
|
||||
}
|
||||
if prop.VT != ole.VT_NULL {
|
||||
for _, v := range prop.ToArray().ToValueArray() {
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
deck.Warningf("error converting %s to string", sliceProp.Name)
|
||||
} else {
|
||||
*sliceProp.Dst = append(*sliceProp.Dst, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StaticRoute sets the static route for the network adapter.
|
||||
//
|
||||
// IMPORTANT: This method ONLY supports Ipv4.
|
||||
//
|
||||
// Ref: https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/enablestatic-method-in-class-win32-networkadapterconfiguration
|
||||
func (n *AdapterConfiguration) StaticRoute(ipaddress []string, subnetMask []string) error {
|
||||
res, err := oleutil.CallMethod(n.handle, "EnableStatic", ipaddress, subnetMask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("EnableStatic: %w", err)
|
||||
}
|
||||
if val, ok := res.Value().(int32); val != 0 || !ok {
|
||||
return fmt.Errorf("error code returned during EnableStatic: %d", val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
75
go/network/iphelper/netadapterconfig_test.go
Normal file
75
go/network/iphelper/netadapterconfig_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License")
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetNetAdapterConfigs(t *testing.T) {
|
||||
n, err := AdapterConnect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to network: %v", err)
|
||||
}
|
||||
defer n.Close()
|
||||
|
||||
// First get all configs to find a valid one for a filter test.
|
||||
allConfigs, err := n.GetNetworkAdapterConfigurations("")
|
||||
if err != nil {
|
||||
t.Fatalf("Initial GetNetAdapterConfigs() failed: %v", err)
|
||||
}
|
||||
if len(allConfigs.NetworkAdapterConfigurations) == 0 {
|
||||
t.Skip("No net adapter configs found, skipping test.")
|
||||
}
|
||||
configToFilter := allConfigs.NetworkAdapterConfigurations[0]
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filter string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no filter",
|
||||
filter: "",
|
||||
},
|
||||
{
|
||||
name: "filter by interface index",
|
||||
filter: fmt.Sprintf("WHERE InterfaceIndex = %d", configToFilter.InterfaceIndex),
|
||||
},
|
||||
{
|
||||
name: "bad filter",
|
||||
filter: "WHERE BadFilter = 'true'",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configs, err := n.GetNetworkAdapterConfigurations(tt.filter)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GetNetAdapterConfigs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return // Expected error, nothing more to check.
|
||||
}
|
||||
|
||||
if len(configs.NetworkAdapterConfigurations) == 0 {
|
||||
t.Fatal("got 0 net adapter configs, want at least 1")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
124
go/network/iphelper/network.go
Normal file
124
go/network/iphelper/network.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package netw provides network adapter management functionality.
|
||||
package netw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/scjalliance/comshim"
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
)
|
||||
|
||||
// Connect connects to the WMI provider for managing storage objects.
|
||||
// You must call Close() to release the provider when finished.
|
||||
//
|
||||
// Example: storage.Connect()
|
||||
func Connect() (Service, error) {
|
||||
comshim.Add(1)
|
||||
svc := Service{}
|
||||
|
||||
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
|
||||
if err != nil {
|
||||
comshim.Done()
|
||||
return svc, fmt.Errorf("CreateObject: %w", err)
|
||||
}
|
||||
defer unknown.Release()
|
||||
svc.wmiIntf, err = unknown.QueryInterface(ole.IID_IDispatch)
|
||||
if err != nil {
|
||||
comshim.Done()
|
||||
return svc, fmt.Errorf("QueryInterface: %w", err)
|
||||
}
|
||||
serviceRaw, err := oleutil.CallMethod(svc.wmiIntf, "ConnectServer", nil, `\\.\ROOT\StandardCimv2`)
|
||||
if err != nil {
|
||||
svc.Close()
|
||||
return svc, fmt.Errorf("ConnectServer: %w", err)
|
||||
}
|
||||
svc.wmiSvc = serviceRaw.ToIDispatch()
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// Close frees all resources associated with a volume.
|
||||
func (svc *Service) Close() {
|
||||
svc.wmiIntf.Release()
|
||||
if svc.wmiSvc != nil {
|
||||
svc.wmiSvc.Release()
|
||||
}
|
||||
comshim.Done()
|
||||
}
|
||||
|
||||
// AssignVariant assigns a variant to a destination.
|
||||
func AssignVariant(v any, dst any) error {
|
||||
switch d := dst.(type) {
|
||||
case *bool:
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot assign %T to *bool", v)
|
||||
}
|
||||
*d = b
|
||||
case *uint8:
|
||||
switch val := v.(type) {
|
||||
case uint8:
|
||||
*d = val
|
||||
case int16:
|
||||
*d = uint8(val)
|
||||
case int32:
|
||||
*d = uint8(val)
|
||||
default:
|
||||
return fmt.Errorf("cannot assign %T to *uint8", v)
|
||||
}
|
||||
case *uint16:
|
||||
i, ok := v.(int32)
|
||||
if !ok {
|
||||
i16, ok16 := v.(int16)
|
||||
if !ok16 {
|
||||
return fmt.Errorf("cannot assign %T to *uint16", v)
|
||||
}
|
||||
i = int32(i16)
|
||||
}
|
||||
*d = uint16(i)
|
||||
case *uint32:
|
||||
i, ok := v.(int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot assign %T to *uint32", v)
|
||||
}
|
||||
*d = uint32(i)
|
||||
case *uint64:
|
||||
s, ok := v.(string)
|
||||
if ok {
|
||||
parsed, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse uint64 from string '%s': %w", s, err)
|
||||
}
|
||||
*d = parsed
|
||||
return nil
|
||||
}
|
||||
i, ok := v.(int64)
|
||||
if !ok {
|
||||
i32, ok32 := v.(int32)
|
||||
if !ok32 {
|
||||
return fmt.Errorf("cannot assign %T to *uint64", v)
|
||||
}
|
||||
i = int64(i32)
|
||||
}
|
||||
*d = uint64(i)
|
||||
default:
|
||||
return fmt.Errorf("unsupported destination type %T", dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user