wmi.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. // +build windows
  2. /*
  3. Package wmi provides a WQL interface for WMI on Windows.
  4. Example code to print names of running processes:
  5. type Win32_Process struct {
  6. Name string
  7. }
  8. func main() {
  9. var dst []Win32_Process
  10. q := wmi.CreateQuery(&dst, "")
  11. err := wmi.Query(q, &dst)
  12. if err != nil {
  13. log.Fatal(err)
  14. }
  15. for i, v := range dst {
  16. println(i, v.Name)
  17. }
  18. }
  19. */
  20. package wmi
  21. import (
  22. "bytes"
  23. "errors"
  24. "fmt"
  25. "log"
  26. "os"
  27. "reflect"
  28. "runtime"
  29. "strconv"
  30. "strings"
  31. "sync"
  32. "time"
  33. "github.com/go-ole/go-ole"
  34. "github.com/go-ole/go-ole/oleutil"
  35. )
  // l is a package-level logger writing to stdout.
  // NOTE(review): not referenced in this file — confirm it is used elsewhere
  // in the package before removing it.
  var l = log.New(os.Stdout, "", log.LstdFlags)

  var (
  	// ErrInvalidEntityType is returned by Query when dst is not a non-nil
  	// pointer to a slice of structs or of struct pointers.
  	ErrInvalidEntityType = errors.New("wmi: invalid entity type")
  	// ErrNilCreateObject is the error returned if CreateObject returns nil even
  	// if the error was nil.
  	ErrNilCreateObject = errors.New("wmi: create object returned nil")
  	// lock is held by Client.Query for the whole duration of a query, so
  	// queries made through it run one at a time.
  	lock sync.Mutex
  )

  // S_FALSE is returned by CoInitializeEx if it was already called on this thread.
  const S_FALSE = 0x00000001
  46. // QueryNamespace invokes Query with the given namespace on the local machine.
  47. func QueryNamespace(query string, dst interface{}, namespace string) error {
  48. return Query(query, dst, nil, namespace)
  49. }
  50. // Query runs the WQL query and appends the values to dst.
  51. //
  52. // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
  53. // the query must have the same name in dst. Supported types are all signed and
  54. // unsigned integers, time.Time, string, bool, or a pointer to one of those.
  55. // Array types are not supported.
  56. //
  57. // By default, the local machine and default namespace are used. These can be
  58. // changed using connectServerArgs. See
  59. // https://docs.microsoft.com/en-us/windows/desktop/WmiSdk/swbemlocator-connectserver
  60. // for details.
  61. //
  62. // Query is a wrapper around DefaultClient.Query.
  63. func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
  64. if DefaultClient.SWbemServicesClient == nil {
  65. return DefaultClient.Query(query, dst, connectServerArgs...)
  66. }
  67. return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...)
  68. }
  69. // CallMethod calls a method named methodName on an instance of the class named
  70. // className, with the given params.
  71. //
  72. // CallMethod is a wrapper around DefaultClient.CallMethod.
  73. func CallMethod(connectServerArgs []interface{}, className, methodName string, params []interface{}) (int32, error) {
  74. return DefaultClient.CallMethod(connectServerArgs, className, methodName, params)
  75. }
  76. // A Client is an WMI query client.
  77. //
  78. // Its zero value (DefaultClient) is a usable client.
  79. type Client struct {
  80. // NonePtrZero specifies if nil values for fields which aren't pointers
  81. // should be returned as the field types zero value.
  82. //
  83. // Setting this to true allows stucts without pointer fields to be used
  84. // without the risk failure should a nil value returned from WMI.
  85. NonePtrZero bool
  86. // PtrNil specifies if nil values for pointer fields should be returned
  87. // as nil.
  88. //
  89. // Setting this to true will set pointer fields to nil where WMI
  90. // returned nil, otherwise the types zero value will be returned.
  91. PtrNil bool
  92. // AllowMissingFields specifies that struct fields not present in the
  93. // query result should not result in an error.
  94. //
  95. // Setting this to true allows custom queries to be used with full
  96. // struct definitions instead of having to define multiple structs.
  97. AllowMissingFields bool
  98. // SWbemServiceClient is an optional SWbemServices object that can be
  99. // initialized and then reused across multiple queries. If it is null
  100. // then the method will initialize a new temporary client each time.
  101. SWbemServicesClient *SWbemServices
  102. }
  103. // DefaultClient is the default Client and is used by Query, QueryNamespace, and CallMethod.
  104. var DefaultClient = &Client{}
  105. // coinitService coinitializes WMI service. If no error is returned, a cleanup function
  106. // is returned which must be executed (usually deferred) to clean up allocated resources.
  107. func (c *Client) coinitService(connectServerArgs ...interface{}) (*ole.IDispatch, func(), error) {
  108. var unknown *ole.IUnknown
  109. var wmi *ole.IDispatch
  110. var serviceRaw *ole.VARIANT
  111. // be sure teardown happens in the reverse
  112. // order from that which they were created
  113. deferFn := func() {
  114. if serviceRaw != nil {
  115. serviceRaw.Clear()
  116. }
  117. if wmi != nil {
  118. wmi.Release()
  119. }
  120. if unknown != nil {
  121. unknown.Release()
  122. }
  123. ole.CoUninitialize()
  124. }
  125. // if we error'ed here, clean up immediately
  126. var err error
  127. defer func() {
  128. if err != nil {
  129. deferFn()
  130. }
  131. }()
  132. err = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
  133. if err != nil {
  134. oleCode := err.(*ole.OleError).Code()
  135. if oleCode != ole.S_OK && oleCode != S_FALSE {
  136. return nil, nil, err
  137. }
  138. }
  139. unknown, err = oleutil.CreateObject("WbemScripting.SWbemLocator")
  140. if err != nil {
  141. return nil, nil, err
  142. } else if unknown == nil {
  143. return nil, nil, ErrNilCreateObject
  144. }
  145. wmi, err = unknown.QueryInterface(ole.IID_IDispatch)
  146. if err != nil {
  147. return nil, nil, err
  148. }
  149. // service is a SWbemServices
  150. serviceRaw, err = oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
  151. if err != nil {
  152. return nil, nil, err
  153. }
  154. return serviceRaw.ToIDispatch(), deferFn, nil
  155. }
  156. // CallMethod calls a WMI method named methodName on an instance
  157. // of the class named className. It passes in the arguments given
  158. // in params. Use connectServerArgs to customize the machine and
  159. // namespace; by default, the local machine and default namespace
  160. // are used. See
  161. // https://docs.microsoft.com/en-us/windows/desktop/WmiSdk/swbemlocator-connectserver
  162. // for details.
  163. func (c *Client) CallMethod(connectServerArgs []interface{}, className, methodName string, params []interface{}) (int32, error) {
  164. service, cleanup, err := c.coinitService(connectServerArgs...)
  165. if err != nil {
  166. return 0, fmt.Errorf("coinit: %v", err)
  167. }
  168. defer cleanup()
  169. // Get class
  170. classRaw, err := oleutil.CallMethod(service, "Get", className)
  171. if err != nil {
  172. return 0, fmt.Errorf("CallMethod Get class %s: %v", className, err)
  173. }
  174. class := classRaw.ToIDispatch()
  175. defer classRaw.Clear()
  176. // Run method
  177. resultRaw, err := oleutil.CallMethod(class, methodName, params...)
  178. if err != nil {
  179. return 0, fmt.Errorf("CallMethod %s.%s: %v", className, methodName, err)
  180. }
  181. resultInt, ok := resultRaw.Value().(int32)
  182. if !ok {
  183. return 0, fmt.Errorf("return value was not an int32: %v (%T)", resultRaw, resultRaw)
  184. }
  185. return resultInt, nil
  186. }
  187. // Query runs the WQL query and appends the values to dst.
  188. //
  189. // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
  190. // the query must have the same name in dst. Supported types are all signed and
  191. // unsigned integers, time.Time, string, bool, or a pointer to one of those.
  192. // Array types are not supported.
  193. //
  194. // By default, the local machine and default namespace are used. These can be
  195. // changed using connectServerArgs. See
  196. // https://docs.microsoft.com/en-us/windows/desktop/WmiSdk/swbemlocator-connectserver
  197. // for details.
  198. func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
  199. dv := reflect.ValueOf(dst)
  200. if dv.Kind() != reflect.Ptr || dv.IsNil() {
  201. return ErrInvalidEntityType
  202. }
  203. dv = dv.Elem()
  204. mat, elemType := checkMultiArg(dv)
  205. if mat == multiArgTypeInvalid {
  206. return ErrInvalidEntityType
  207. }
  208. lock.Lock()
  209. defer lock.Unlock()
  210. runtime.LockOSThread()
  211. defer runtime.UnlockOSThread()
  212. service, cleanup, err := c.coinitService(connectServerArgs...)
  213. if err != nil {
  214. return err
  215. }
  216. defer cleanup()
  217. // result is a SWBemObjectSet
  218. resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
  219. if err != nil {
  220. return err
  221. }
  222. result := resultRaw.ToIDispatch()
  223. defer resultRaw.Clear()
  224. count, err := oleInt64(result, "Count")
  225. if err != nil {
  226. return err
  227. }
  228. enumProperty, err := result.GetProperty("_NewEnum")
  229. if err != nil {
  230. return err
  231. }
  232. defer enumProperty.Clear()
  233. enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
  234. if err != nil {
  235. return err
  236. }
  237. if enum == nil {
  238. return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
  239. }
  240. defer enum.Release()
  241. // Initialize a slice with Count capacity
  242. dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
  243. var errFieldMismatch error
  244. for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
  245. if err != nil {
  246. return err
  247. }
  248. err := func() error {
  249. // item is a SWbemObject, but really a Win32_Process
  250. item := itemRaw.ToIDispatch()
  251. defer item.Release()
  252. ev := reflect.New(elemType)
  253. if err = c.loadEntity(ev.Interface(), item); err != nil {
  254. if _, ok := err.(*ErrFieldMismatch); ok {
  255. // We continue loading entities even in the face of field mismatch errors.
  256. // If we encounter any other error, that other error is returned. Otherwise,
  257. // an ErrFieldMismatch is returned.
  258. errFieldMismatch = err
  259. } else {
  260. return err
  261. }
  262. }
  263. if mat != multiArgTypeStructPtr {
  264. ev = ev.Elem()
  265. }
  266. dv.Set(reflect.Append(dv, ev))
  267. return nil
  268. }()
  269. if err != nil {
  270. return err
  271. }
  272. }
  273. return errFieldMismatch
  274. }
  // ErrFieldMismatch is returned when a WMI property cannot be loaded into the
  // destination struct field: its value has a type the field cannot hold, the
  // property does not exist, or the field cannot be set.
  //
  // StructType identifies the destination type being loaded into (intended to
  // be the struct pointed to by the destination argument), FieldName is the
  // offending field and Reason is a short human-readable explanation.
  type ErrFieldMismatch struct {
  	StructType reflect.Type
  	FieldName  string
  	Reason     string
  }

  // Error formats the mismatch as
  // `wmi: cannot load field "<FieldName>" into a "<StructType>": <Reason>`.
  func (e *ErrFieldMismatch) Error() string {
  	head := fmt.Sprintf("wmi: cannot load field %q into a %q", e.FieldName, e.StructType)
  	return head + ": " + e.Reason
  }
  // timeType is the reflect.Type of time.Time, used to recognise time fields
  // that are filled from WMI datetime strings.
  var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
  289. // loadEntity loads a SWbemObject into a struct pointer.
  290. func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
  291. v := reflect.ValueOf(dst).Elem()
  292. for i := 0; i < v.NumField(); i++ {
  293. f := v.Field(i)
  294. of := f
  295. isPtr := f.Kind() == reflect.Ptr
  296. if isPtr {
  297. ptr := reflect.New(f.Type().Elem())
  298. f.Set(ptr)
  299. f = f.Elem()
  300. }
  301. n := v.Type().Field(i).Name
  302. if n[0] < 'A' || n[0] > 'Z' {
  303. continue
  304. }
  305. if !f.CanSet() {
  306. return &ErrFieldMismatch{
  307. StructType: of.Type(),
  308. FieldName: n,
  309. Reason: "CanSet() is false",
  310. }
  311. }
  312. prop, err := oleutil.GetProperty(src, n)
  313. if err != nil {
  314. if !c.AllowMissingFields {
  315. errFieldMismatch = &ErrFieldMismatch{
  316. StructType: of.Type(),
  317. FieldName: n,
  318. Reason: "no such struct field",
  319. }
  320. }
  321. continue
  322. }
  323. defer prop.Clear()
  324. if prop.VT == 0x1 { //VT_NULL
  325. continue
  326. }
  327. switch val := prop.Value().(type) {
  328. case int8, int16, int32, int64, int:
  329. v := reflect.ValueOf(val).Int()
  330. switch f.Kind() {
  331. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  332. f.SetInt(v)
  333. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  334. f.SetUint(uint64(v))
  335. default:
  336. return &ErrFieldMismatch{
  337. StructType: of.Type(),
  338. FieldName: n,
  339. Reason: "not an integer class",
  340. }
  341. }
  342. case uint8, uint16, uint32, uint64:
  343. v := reflect.ValueOf(val).Uint()
  344. switch f.Kind() {
  345. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  346. f.SetInt(int64(v))
  347. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  348. f.SetUint(v)
  349. default:
  350. return &ErrFieldMismatch{
  351. StructType: of.Type(),
  352. FieldName: n,
  353. Reason: "not an integer class",
  354. }
  355. }
  356. case string:
  357. switch f.Kind() {
  358. case reflect.String:
  359. f.SetString(val)
  360. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  361. iv, err := strconv.ParseInt(val, 10, 64)
  362. if err != nil {
  363. return err
  364. }
  365. f.SetInt(iv)
  366. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  367. uv, err := strconv.ParseUint(val, 10, 64)
  368. if err != nil {
  369. return err
  370. }
  371. f.SetUint(uv)
  372. case reflect.Struct:
  373. switch f.Type() {
  374. case timeType:
  375. if len(val) == 25 {
  376. mins, err := strconv.Atoi(val[22:])
  377. if err != nil {
  378. return err
  379. }
  380. val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
  381. }
  382. t, err := time.Parse("20060102150405.000000-0700", val)
  383. if err != nil {
  384. return err
  385. }
  386. f.Set(reflect.ValueOf(t))
  387. }
  388. }
  389. case bool:
  390. switch f.Kind() {
  391. case reflect.Bool:
  392. f.SetBool(val)
  393. default:
  394. return &ErrFieldMismatch{
  395. StructType: of.Type(),
  396. FieldName: n,
  397. Reason: "not a bool",
  398. }
  399. }
  400. case float32:
  401. switch f.Kind() {
  402. case reflect.Float32:
  403. f.SetFloat(float64(val))
  404. default:
  405. return &ErrFieldMismatch{
  406. StructType: of.Type(),
  407. FieldName: n,
  408. Reason: "not a Float32",
  409. }
  410. }
  411. default:
  412. if f.Kind() == reflect.Slice {
  413. switch f.Type().Elem().Kind() {
  414. case reflect.String:
  415. safeArray := prop.ToArray()
  416. if safeArray != nil {
  417. arr := safeArray.ToValueArray()
  418. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  419. for i, v := range arr {
  420. s := fArr.Index(i)
  421. s.SetString(v.(string))
  422. }
  423. f.Set(fArr)
  424. }
  425. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
  426. safeArray := prop.ToArray()
  427. if safeArray != nil {
  428. arr := safeArray.ToValueArray()
  429. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  430. for i, v := range arr {
  431. s := fArr.Index(i)
  432. s.SetUint(reflect.ValueOf(v).Uint())
  433. }
  434. f.Set(fArr)
  435. }
  436. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
  437. safeArray := prop.ToArray()
  438. if safeArray != nil {
  439. arr := safeArray.ToValueArray()
  440. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  441. for i, v := range arr {
  442. s := fArr.Index(i)
  443. s.SetInt(reflect.ValueOf(v).Int())
  444. }
  445. f.Set(fArr)
  446. }
  447. default:
  448. return &ErrFieldMismatch{
  449. StructType: of.Type(),
  450. FieldName: n,
  451. Reason: fmt.Sprintf("unsupported slice type (%T)", val),
  452. }
  453. }
  454. } else {
  455. typeof := reflect.TypeOf(val)
  456. if typeof == nil && (isPtr || c.NonePtrZero) {
  457. if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
  458. of.Set(reflect.Zero(of.Type()))
  459. }
  460. break
  461. }
  462. return &ErrFieldMismatch{
  463. StructType: of.Type(),
  464. FieldName: n,
  465. Reason: fmt.Sprintf("unsupported type (%T)", val),
  466. }
  467. }
  468. }
  469. }
  470. return errFieldMismatch
  471. }
  // multiArgType classifies the element type of the slice that Query loads
  // its results into.
  type multiArgType int

  const (
  	// multiArgTypeInvalid: neither a slice of structs nor of struct pointers.
  	multiArgTypeInvalid multiArgType = iota
  	// multiArgTypeStruct: a []S for some struct type S.
  	multiArgTypeStruct
  	// multiArgTypeStructPtr: a []*S for some struct type S.
  	multiArgTypeStructPtr
  )

  // checkMultiArg reports whether v is a []S or a []*S for some struct type S.
  // It returns the matching category together with S itself, or
  // multiArgTypeInvalid and a nil type when v has any other shape.
  func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
  	if v.Kind() == reflect.Slice {
  		elem := v.Type().Elem()
  		if elem.Kind() == reflect.Struct {
  			return multiArgTypeStruct, elem
  		}
  		if elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Struct {
  			return multiArgTypeStructPtr, elem.Elem()
  		}
  	}
  	return multiArgTypeInvalid, nil
  }
  498. func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
  499. v, err := oleutil.GetProperty(item, prop)
  500. if err != nil {
  501. return 0, err
  502. }
  503. defer v.Clear()
  504. i := int64(v.Val)
  505. return i, nil
  506. }
  // CreateQuery returns a WQL query string that queries all columns of src.
  //
  // src may be a struct, a slice of structs or of struct pointers, or a
  // pointer to any of those. Every field whose name starts with an ASCII
  // upper-case letter becomes a selected column. These are the same fields the
  // loader fills in; other fields are skipped.
  //
  // where is an optional string that is appended to the query, to be used with
  // WHERE clauses. In such a case, the "WHERE" string should appear at the
  // beginning.
  //
  // The WMI class is obtained from the name of the struct type. You can pass
  // an optional class through the variadic class parameter, which is useful
  // for anonymous structs.
  //
  // An empty string is returned when src is nil or is not one of the shapes
  // described above.
  func CreateQuery(src interface{}, where string, class ...string) string {
  	s := reflect.Indirect(reflect.ValueOf(src))
  	if !s.IsValid() {
  		// src was nil or a nil pointer: there is no type to describe.
  		return ""
  	}
  	t := s.Type()
  	if t.Kind() == reflect.Slice {
  		t = t.Elem()
  		// Query accepts *[]*S, so accept []*S here as well.
  		if t.Kind() == reflect.Ptr {
  			t = t.Elem()
  		}
  	}
  	if t.Kind() != reflect.Struct {
  		return ""
  	}

  	var fields []string
  	for i := 0; i < t.NumField(); i++ {
  		n := t.Field(i).Name
  		// Skip fields the loader never fills so they do not end up in the
  		// SELECT list (WQL names are case-insensitive, so they could match).
  		if n[0] < 'A' || n[0] > 'Z' {
  			continue
  		}
  		fields = append(fields, n)
  	}

  	var b bytes.Buffer
  	b.WriteString("SELECT ")
  	b.WriteString(strings.Join(fields, ", "))
  	b.WriteString(" FROM ")
  	if len(class) > 0 {
  		b.WriteString(class[0])
  	} else {
  		b.WriteString(t.Name())
  	}
  	b.WriteString(" " + where)
  	return b.String()
  }